Home | History | Annotate | Download | only in src
      1 /*M///////////////////////////////////////////////////////////////////////////////////////
      2 //
      3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
      4 //
      5 //  By downloading, copying, installing or using the software you agree to this license.
      6 //  If you do not agree to this license, do not download, install,
      7 //  copy or use the software.
      8 //
      9 //
     10 //                          License Agreement
     11 //                For Open Source Computer Vision Library
     12 //
     13 // Copyright (C) 2000-2008, Intel Corporation, all rights reserved.
     14 // Copyright (C) 2009, Willow Garage Inc., all rights reserved.
     15 // Copyright (C) 2013, OpenCV Foundation, all rights reserved.
     16 // Copyright (C) 2014, Itseez Inc, all rights reserved.
     17 // Third party copyrights are property of their respective owners.
     18 //
     19 // Redistribution and use in source and binary forms, with or without modification,
     20 // are permitted provided that the following conditions are met:
     21 //
     22 //   * Redistribution's of source code must retain the above copyright notice,
     23 //     this list of conditions and the following disclaimer.
     24 //
     25 //   * Redistribution's in binary form must reproduce the above copyright notice,
     26 //     this list of conditions and the following disclaimer in the documentation
     27 //     and/or other materials provided with the distribution.
     28 //
     29 //   * The name of the copyright holders may not be used to endorse or promote products
     30 //     derived from this software without specific prior written permission.
     31 //
     32 // This software is provided by the copyright holders and contributors "as is" and
     33 // any express or implied warranties, including, but not limited to, the implied
     34 // warranties of merchantability and fitness for a particular purpose are disclaimed.
     35 // In no event shall the Intel Corporation or contributors be liable for any direct,
     36 // indirect, incidental, special, exemplary, or consequential damages
     37 // (including, but not limited to, procurement of substitute goods or services;
     38 // loss of use, data, or profits; or business interruption) however caused
     39 // and on any theory of liability, whether in contract, strict liability,
     40 // or tort (including negligence or otherwise) arising in any way out of
     41 // the use of this software, even if advised of the possibility of such damage.
     42 //
     43 //M*/
     44 
     45 #include "precomp.hpp"
     46 #include "kdtree.hpp"
     47 
     48 namespace cv
     49 {
     50 namespace ml
     51 {
// This is a reimplementation of the kd-trees from cvkdtree*.* by Xavier Delacour, cleaned up and
// adapted to work with the new OpenCV data structures.
     54 
     55 // The algorithm is taken from:
// J.S. Beis and D.G. Lowe. Shape indexing using approximate nearest-neighbor search
// in high-dimensional spaces. In Proc. IEEE Conf. Comp. Vision Patt. Recog.,
// pages 1000--1006, 1997. http://citeseer.ist.psu.edu/beis97shape.html
     59 
// Sizes the explicit traversal stacks: build() allocates MAX_TREE_DEPTH*2
// SubTree entries (and as many rows of running sums), findOrthoRange()
// MAX_TREE_DEPTH*2 + 1 node indices. NOTE(review): not enforced at runtime;
// it presumably relies on median splits keeping the depth near log2(n),
// which stays below 32 for any int-sized point count.
const int MAX_TREE_DEPTH = 32;
     61 
     62 KDTree::KDTree()
     63 {
     64     maxDepth = -1;
     65     normType = NORM_L2;
     66 }
     67 
     68 KDTree::KDTree(InputArray _points, bool _copyData)
     69 {
     70     maxDepth = -1;
     71     normType = NORM_L2;
     72     build(_points, _copyData);
     73 }
     74 
     75 KDTree::KDTree(InputArray _points, InputArray _labels, bool _copyData)
     76 {
     77     maxDepth = -1;
     78     normType = NORM_L2;
     79     build(_points, _labels, _copyData);
     80 }
     81 
     82 struct SubTree
     83 {
     84     SubTree() : first(0), last(0), nodeIdx(0), depth(0) {}
     85     SubTree(int _first, int _last, int _nodeIdx, int _depth)
     86         : first(_first), last(_last), nodeIdx(_nodeIdx), depth(_depth) {}
     87     int first;
     88     int last;
     89     int nodeIdx;
     90     int depth;
     91 };
     92 
     93 
     94 static float
     95 medianPartition( size_t* ofs, int a, int b, const float* vals )
     96 {
     97     int k, a0 = a, b0 = b;
     98     int middle = (a + b)/2;
     99     while( b > a )
    100     {
    101         int i0 = a, i1 = (a+b)/2, i2 = b;
    102         float v0 = vals[ofs[i0]], v1 = vals[ofs[i1]], v2 = vals[ofs[i2]];
    103         int ip = v0 < v1 ? (v1 < v2 ? i1 : v0 < v2 ? i2 : i0) :
    104             v0 < v2 ? i0 : (v1 < v2 ? i2 : i1);
    105         float pivot = vals[ofs[ip]];
    106         std::swap(ofs[ip], ofs[i2]);
    107 
    108         for( i1 = i0, i0--; i1 <= i2; i1++ )
    109             if( vals[ofs[i1]] <= pivot )
    110             {
    111                 i0++;
    112                 std::swap(ofs[i0], ofs[i1]);
    113             }
    114         if( i0 == middle )
    115             break;
    116         if( i0 > middle )
    117             b = i0 - (b == i0);
    118         else
    119             a = i0;
    120     }
    121 
    122     float pivot = vals[ofs[middle]];
    123     int less = 0, more = 0;
    124     for( k = a0; k < middle; k++ )
    125     {
    126         CV_Assert(vals[ofs[k]] <= pivot);
    127         less += vals[ofs[k]] < pivot;
    128     }
    129     for( k = b0; k > middle; k-- )
    130     {
    131         CV_Assert(vals[ofs[k]] >= pivot);
    132         more += vals[ofs[k]] > pivot;
    133     }
    134     CV_Assert(std::abs(more - less) <= 1);
    135 
    136     return vals[ofs[middle]];
    137 }
    138 
    139 static void
    140 computeSums( const Mat& points, const size_t* ofs, int a, int b, double* sums )
    141 {
    142     int i, j, dims = points.cols;
    143     const float* data = points.ptr<float>(0);
    144     for( j = 0; j < dims; j++ )
    145         sums[j*2] = sums[j*2+1] = 0;
    146     for( i = a; i <= b; i++ )
    147     {
    148         const float* row = data + ofs[i];
    149         for( j = 0; j < dims; j++ )
    150         {
    151             double t = row[j], s = sums[j*2] + t, s2 = sums[j*2+1] + t*t;
    152             sums[j*2] = s; sums[j*2+1] = s2;
    153         }
    154     }
    155 }
    156 
    157 
    158 void KDTree::build(InputArray _points, bool _copyData)
    159 {
    160     build(_points, noArray(), _copyData);
    161 }
    162 
    163 
    164 void KDTree::build(InputArray __points, InputArray __labels, bool _copyData)
    165 {
    166     Mat _points = __points.getMat(), _labels = __labels.getMat();
    167     CV_Assert(_points.type() == CV_32F && !_points.empty());
    168     std::vector<KDTree::Node>().swap(nodes);
    169 
    170     if( !_copyData )
    171         points = _points;
    172     else
    173     {
    174         points.release();
    175         points.create(_points.size(), _points.type());
    176     }
    177 
    178     int i, j, n = _points.rows, ptdims = _points.cols, top = 0;
    179     const float* data = _points.ptr<float>(0);
    180     float* dstdata = points.ptr<float>(0);
    181     size_t step = _points.step1();
    182     size_t dstep = points.step1();
    183     int ptpos = 0;
    184     labels.resize(n);
    185     const int* _labels_data = 0;
    186 
    187     if( !_labels.empty() )
    188     {
    189         int nlabels = _labels.checkVector(1, CV_32S, true);
    190         CV_Assert(nlabels == n);
    191         _labels_data = _labels.ptr<int>();
    192     }
    193 
    194     Mat sumstack(MAX_TREE_DEPTH*2, ptdims*2, CV_64F);
    195     SubTree stack[MAX_TREE_DEPTH*2];
    196 
    197     std::vector<size_t> _ptofs(n);
    198     size_t* ptofs = &_ptofs[0];
    199 
    200     for( i = 0; i < n; i++ )
    201         ptofs[i] = i*step;
    202 
    203     nodes.push_back(Node());
    204     computeSums(points, ptofs, 0, n-1, sumstack.ptr<double>(top));
    205     stack[top++] = SubTree(0, n-1, 0, 0);
    206     int _maxDepth = 0;
    207 
    208     while( --top >= 0 )
    209     {
    210         int first = stack[top].first, last = stack[top].last;
    211         int depth = stack[top].depth, nidx = stack[top].nodeIdx;
    212         int count = last - first + 1, dim = -1;
    213         const double* sums = sumstack.ptr<double>(top);
    214         double invCount = 1./count, maxVar = -1.;
    215 
    216         if( count == 1 )
    217         {
    218             int idx0 = (int)(ptofs[first]/step);
    219             int idx = _copyData ? ptpos++ : idx0;
    220             nodes[nidx].idx = ~idx;
    221             if( _copyData )
    222             {
    223                 const float* src = data + ptofs[first];
    224                 float* dst = dstdata + idx*dstep;
    225                 for( j = 0; j < ptdims; j++ )
    226                     dst[j] = src[j];
    227             }
    228             labels[idx] = _labels_data ? _labels_data[idx0] : idx0;
    229             _maxDepth = std::max(_maxDepth, depth);
    230             continue;
    231         }
    232 
    233         // find the dimensionality with the biggest variance
    234         for( j = 0; j < ptdims; j++ )
    235         {
    236             double m = sums[j*2]*invCount;
    237             double varj = sums[j*2+1]*invCount - m*m;
    238             if( maxVar < varj )
    239             {
    240                 maxVar = varj;
    241                 dim = j;
    242             }
    243         }
    244 
    245         int left = (int)nodes.size(), right = left + 1;
    246         nodes.push_back(Node());
    247         nodes.push_back(Node());
    248         nodes[nidx].idx = dim;
    249         nodes[nidx].left = left;
    250         nodes[nidx].right = right;
    251         nodes[nidx].boundary = medianPartition(ptofs, first, last, data + dim);
    252 
    253         int middle = (first + last)/2;
    254         double *lsums = (double*)sums, *rsums = lsums + ptdims*2;
    255         computeSums(points, ptofs, middle+1, last, rsums);
    256         for( j = 0; j < ptdims*2; j++ )
    257             lsums[j] = sums[j] - rsums[j];
    258         stack[top++] = SubTree(first, middle, left, depth+1);
    259         stack[top++] = SubTree(middle+1, last, right, depth+1);
    260     }
    261     maxDepth = _maxDepth;
    262 }
    263 
    264 
    265 struct PQueueElem
    266 {
    267     PQueueElem() : dist(0), idx(0) {}
    268     PQueueElem(float _dist, int _idx) : dist(_dist), idx(_idx) {}
    269     float dist;
    270     int idx;
    271 };
    272 
    273 
    274 int KDTree::findNearest(InputArray _vec, int K, int emax,
    275                         OutputArray _neighborsIdx, OutputArray _neighbors,
    276                         OutputArray _dist, OutputArray _labels) const
    277 
    278 {
    279     Mat vecmat = _vec.getMat();
    280     CV_Assert( vecmat.isContinuous() && vecmat.type() == CV_32F && vecmat.total() == (size_t)points.cols );
    281     const float* vec = vecmat.ptr<float>();
    282     K = std::min(K, points.rows);
    283     int ptdims = points.cols;
    284 
    285     CV_Assert(K > 0 && (normType == NORM_L2 || normType == NORM_L1));
    286 
    287     AutoBuffer<uchar> _buf((K+1)*(sizeof(float) + sizeof(int)));
    288     int* idx = (int*)(uchar*)_buf;
    289     float* dist = (float*)(idx + K + 1);
    290     int i, j, ncount = 0, e = 0;
    291 
    292     int qsize = 0, maxqsize = 1 << 10;
    293     AutoBuffer<uchar> _pqueue(maxqsize*sizeof(PQueueElem));
    294     PQueueElem* pqueue = (PQueueElem*)(uchar*)_pqueue;
    295     emax = std::max(emax, 1);
    296 
    297     for( e = 0; e < emax; )
    298     {
    299         float d, alt_d = 0.f;
    300         int nidx;
    301 
    302         if( e == 0 )
    303             nidx = 0;
    304         else
    305         {
    306             // take the next node from the priority queue
    307             if( qsize == 0 )
    308                 break;
    309             nidx = pqueue[0].idx;
    310             alt_d = pqueue[0].dist;
    311             if( --qsize > 0 )
    312             {
    313                 std::swap(pqueue[0], pqueue[qsize]);
    314                 d = pqueue[0].dist;
    315                 for( i = 0;;)
    316                 {
    317                     int left = i*2 + 1, right = i*2 + 2;
    318                     if( left >= qsize )
    319                         break;
    320                     if( right < qsize && pqueue[right].dist < pqueue[left].dist )
    321                         left = right;
    322                     if( pqueue[left].dist >= d )
    323                         break;
    324                     std::swap(pqueue[i], pqueue[left]);
    325                     i = left;
    326                 }
    327             }
    328 
    329             if( ncount == K && alt_d > dist[ncount-1] )
    330                 continue;
    331         }
    332 
    333         for(;;)
    334         {
    335             if( nidx < 0 )
    336                 break;
    337             const Node& n = nodes[nidx];
    338 
    339             if( n.idx < 0 )
    340             {
    341                 i = ~n.idx;
    342                 const float* row = points.ptr<float>(i);
    343                 if( normType == NORM_L2 )
    344                     for( j = 0, d = 0.f; j < ptdims; j++ )
    345                     {
    346                         float t = vec[j] - row[j];
    347                         d += t*t;
    348                     }
    349                 else
    350                     for( j = 0, d = 0.f; j < ptdims; j++ )
    351                         d += std::abs(vec[j] - row[j]);
    352 
    353                 dist[ncount] = d;
    354                 idx[ncount] = i;
    355                 for( i = ncount-1; i >= 0; i-- )
    356                 {
    357                     if( dist[i] <= d )
    358                         break;
    359                     std::swap(dist[i], dist[i+1]);
    360                     std::swap(idx[i], idx[i+1]);
    361                 }
    362                 ncount += ncount < K;
    363                 e++;
    364                 break;
    365             }
    366 
    367             int alt;
    368             if( vec[n.idx] <= n.boundary )
    369             {
    370                 nidx = n.left;
    371                 alt = n.right;
    372             }
    373             else
    374             {
    375                 nidx = n.right;
    376                 alt = n.left;
    377             }
    378 
    379             d = vec[n.idx] - n.boundary;
    380             if( normType == NORM_L2 )
    381                 d = d*d + alt_d;
    382             else
    383                 d = std::abs(d) + alt_d;
    384             // subtree prunning
    385             if( ncount == K && d > dist[ncount-1] )
    386                 continue;
    387             // add alternative subtree to the priority queue
    388             pqueue[qsize] = PQueueElem(d, alt);
    389             for( i = qsize; i > 0; )
    390             {
    391                 int parent = (i-1)/2;
    392                 if( parent < 0 || pqueue[parent].dist <= d )
    393                     break;
    394                 std::swap(pqueue[i], pqueue[parent]);
    395                 i = parent;
    396             }
    397             qsize += qsize+1 < maxqsize;
    398         }
    399     }
    400 
    401     K = std::min(K, ncount);
    402     if( _neighborsIdx.needed() )
    403     {
    404         _neighborsIdx.create(K, 1, CV_32S, -1, true);
    405         Mat nidx = _neighborsIdx.getMat();
    406         Mat(nidx.size(), CV_32S, &idx[0]).copyTo(nidx);
    407     }
    408     if( _dist.needed() )
    409         sqrt(Mat(K, 1, CV_32F, dist), _dist);
    410 
    411     if( _neighbors.needed() || _labels.needed() )
    412         getPoints(Mat(K, 1, CV_32S, idx), _neighbors, _labels);
    413     return K;
    414 }
    415 
    416 
    417 void KDTree::findOrthoRange(InputArray _lowerBound,
    418                             InputArray _upperBound,
    419                             OutputArray _neighborsIdx,
    420                             OutputArray _neighbors,
    421                             OutputArray _labels ) const
    422 {
    423     int ptdims = points.cols;
    424     Mat lowerBound = _lowerBound.getMat(), upperBound = _upperBound.getMat();
    425     CV_Assert( lowerBound.size == upperBound.size &&
    426                lowerBound.isContinuous() &&
    427                upperBound.isContinuous() &&
    428                lowerBound.type() == upperBound.type() &&
    429                lowerBound.type() == CV_32F &&
    430                lowerBound.total() == (size_t)ptdims );
    431     const float* L = lowerBound.ptr<float>();
    432     const float* R = upperBound.ptr<float>();
    433 
    434     std::vector<int> idx;
    435     AutoBuffer<int> _stack(MAX_TREE_DEPTH*2 + 1);
    436     int* stack = _stack;
    437     int top = 0;
    438 
    439     stack[top++] = 0;
    440 
    441     while( --top >= 0 )
    442     {
    443         int nidx = stack[top];
    444         if( nidx < 0 )
    445             break;
    446         const Node& n = nodes[nidx];
    447         if( n.idx < 0 )
    448         {
    449             int j, i = ~n.idx;
    450             const float* row = points.ptr<float>(i);
    451             for( j = 0; j < ptdims; j++ )
    452                 if( row[j] < L[j] || row[j] >= R[j] )
    453                     break;
    454             if( j == ptdims )
    455                 idx.push_back(i);
    456             continue;
    457         }
    458         if( L[n.idx] <= n.boundary )
    459             stack[top++] = n.left;
    460         if( R[n.idx] > n.boundary )
    461             stack[top++] = n.right;
    462     }
    463 
    464     if( _neighborsIdx.needed() )
    465     {
    466         _neighborsIdx.create((int)idx.size(), 1, CV_32S, -1, true);
    467         Mat nidx = _neighborsIdx.getMat();
    468         Mat(nidx.size(), CV_32S, &idx[0]).copyTo(nidx);
    469     }
    470     getPoints( idx, _neighbors, _labels );
    471 }
    472 
    473 
    474 void KDTree::getPoints(InputArray _idx, OutputArray _pts, OutputArray _labels) const
    475 {
    476     Mat idxmat = _idx.getMat(), pts, labelsmat;
    477     CV_Assert( idxmat.isContinuous() && idxmat.type() == CV_32S &&
    478                (idxmat.cols == 1 || idxmat.rows == 1) );
    479     const int* idx = idxmat.ptr<int>();
    480     int* dstlabels = 0;
    481 
    482     int ptdims = points.cols;
    483     int i, nidx = (int)idxmat.total();
    484     if( nidx == 0 )
    485     {
    486         _pts.release();
    487         _labels.release();
    488         return;
    489     }
    490 
    491     if( _pts.needed() )
    492     {
    493         _pts.create( nidx, ptdims, points.type());
    494         pts = _pts.getMat();
    495     }
    496 
    497     if(_labels.needed())
    498     {
    499         _labels.create(nidx, 1, CV_32S, -1, true);
    500         labelsmat = _labels.getMat();
    501         CV_Assert( labelsmat.isContinuous() );
    502         dstlabels = labelsmat.ptr<int>();
    503     }
    504     const int* srclabels = !labels.empty() ? &labels[0] : 0;
    505 
    506     for( i = 0; i < nidx; i++ )
    507     {
    508         int k = idx[i];
    509         CV_Assert( (unsigned)k < (unsigned)points.rows );
    510         const float* src = points.ptr<float>(k);
    511         if( !pts.empty() )
    512             std::copy(src, src + ptdims, pts.ptr<float>(i));
    513         if( dstlabels )
    514             dstlabels[i] = srclabels ? srclabels[k] : k;
    515     }
    516 }
    517 
    518 
    519 const float* KDTree::getPoint(int ptidx, int* label) const
    520 {
    521     CV_Assert( (unsigned)ptidx < (unsigned)points.rows);
    522     if(label)
    523         *label = labels[ptidx];
    524     return points.ptr<float>(ptidx);
    525 }
    526 
    527 
    528 int KDTree::dims() const
    529 {
    530     return !points.empty() ? points.cols : 0;
    531 }
    532 
    533 }
    534 }
    535