Home | History | Annotate | Download | only in src
      1 /*M///////////////////////////////////////////////////////////////////////////////////////
      2 //
      3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
      4 //
      5 //  By downloading, copying, installing or using the software you agree to this license.
      6 //  If you do not agree to this license, do not download, install,
      7 //  copy or use the software.
      8 //
      9 //
     10 //                        Intel License Agreement
     11 //                For Open Source Computer Vision Library
     12 //
     13 // Copyright (C) 2008, Xavier Delacour, all rights reserved.
     14 // Third party copyrights are property of their respective owners.
     15 //
     16 // Redistribution and use in source and binary forms, with or without modification,
     17 // are permitted provided that the following conditions are met:
     18 //
     19 //   * Redistribution's of source code must retain the above copyright notice,
     20 //     this list of conditions and the following disclaimer.
     21 //
     22 //   * Redistribution's in binary form must reproduce the above copyright notice,
     23 //     this list of conditions and the following disclaimer in the documentation
     24 //     and/or other materials provided with the distribution.
     25 //
     26 //   * The name of Intel Corporation may not be used to endorse or promote products
     27 //     derived from this software without specific prior written permission.
     28 //
     29 // This software is provided by the copyright holders and contributors "as is" and
     30 // any express or implied warranties, including, but not limited to, the implied
     31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
     32 // In no event shall the Intel Corporation or contributors be liable for any direct,
     33 // indirect, incidental, special, exemplary, or consequential damages
     34 // (including, but not limited to, procurement of substitute goods or services;
     35 // loss of use, data, or profits; or business interruption) however caused
     36 // and on any theory of liability, whether in contract, strict liability,
     37 // or tort (including negligence or otherwise) arising in any way out of
     38 // the use of this software, even if advised of the possibility of such damage.
     39 //
     40 //M*/
     41 
     42 // 2008-05-13, Xavier Delacour <xavier.delacour (at) gmail.com>
     43 
     44 #ifndef __cv_kdtree_h__
     45 #define __cv_kdtree_h__
     46 
     47 #include "_cv.h"
     48 
     49 #include <vector>
     50 #include <algorithm>
     51 #include <limits>
     52 #include <iostream>
     53 #include "assert.h"
     54 #include "math.h"
     55 
// J.S. Beis and D.G. Lowe. Shape indexing using approximate nearest-neighbor search in high-dimensional spaces. In Proc. IEEE Conf. Comp. Vision Patt. Recog., pages 1000--1006, 1997. http://citeseer.ist.psu.edu/beis97shape.html
     57 #undef __deref
     58 #undef __valuetype
     59 
     60 template < class __valuetype, class __deref >
     61 class CvKDTree {
     62 public:
     63   typedef __deref deref_type;
     64   typedef typename __deref::scalar_type scalar_type;
     65   typedef typename __deref::accum_type accum_type;
     66 
     67 private:
     68   struct node {
     69     int dim;			// split dimension; >=0 for nodes, -1 for leaves
     70     __valuetype value;		// if leaf, value of leaf
     71     int left, right;		// node indices of left and right branches
     72     scalar_type boundary;	// left if deref(value,dim)<=boundary, otherwise right
     73   };
     74   typedef std::vector < node > node_array;
     75 
     76   __deref deref;		// requires operator() (__valuetype lhs,int dim)
     77 
     78   node_array nodes;		// node storage
     79   int point_dim;		// dimension of points (the k in kd-tree)
     80   int root_node;		// index of root node, -1 if empty tree
     81 
     82   // for given set of point indices, compute dimension of highest variance
     83   template < class __instype, class __valuector >
     84   int dimension_of_highest_variance(__instype * first, __instype * last,
     85 				    __valuector ctor) {
     86     assert(last - first > 0);
     87 
     88     accum_type maxvar = -std::numeric_limits < accum_type >::max();
     89     int maxj = -1;
     90     for (int j = 0; j < point_dim; ++j) {
     91       accum_type mean = 0;
     92       for (__instype * k = first; k < last; ++k)
     93 	mean += deref(ctor(*k), j);
     94       mean /= last - first;
     95       accum_type var = 0;
     96       for (__instype * k = first; k < last; ++k) {
     97 	accum_type diff = accum_type(deref(ctor(*k), j)) - mean;
     98 	var += diff * diff;
     99       }
    100       var /= last - first;
    101 
    102       assert(maxj != -1 || var >= maxvar);
    103 
    104       if (var >= maxvar) {
    105 	maxvar = var;
    106 	maxj = j;
    107       }
    108     }
    109 
    110     return maxj;
    111   }
    112 
    113   // given point indices and dimension, find index of median; (almost) modifies [first,last)
    114   // such that points_in[first,median]<=point[median], points_in(median,last)>point[median].
    115   // implemented as partial quicksort; expected linear perf.
    116   template < class __instype, class __valuector >
    117   __instype * median_partition(__instype * first, __instype * last,
    118 			       int dim, __valuector ctor) {
    119     assert(last - first > 0);
    120     __instype *k = first + (last - first) / 2;
    121     median_partition(first, last, k, dim, ctor);
    122     return k;
    123   }
    124 
    125   template < class __instype, class __valuector >
    126   struct median_pr {
    127     const __instype & pivot;
    128     int dim;
    129     __deref deref;
    130     __valuector ctor;
    131     median_pr(const __instype & _pivot, int _dim, __deref _deref, __valuector _ctor)
    132       : pivot(_pivot), dim(_dim), deref(_deref), ctor(_ctor) {
    133     }
    134     bool operator() (const __instype & lhs) const {
    135       return deref(ctor(lhs), dim) <= deref(ctor(pivot), dim);
    136     }
    137   };
    138 
    139   template < class __instype, class __valuector >
    140   void median_partition(__instype * first, __instype * last,
    141 			__instype * k, int dim, __valuector ctor) {
    142     int pivot = (last - first) / 2;
    143 
    144     std::swap(first[pivot], last[-1]);
    145     __instype *middle = std::partition(first, last - 1,
    146 				       median_pr < __instype, __valuector >
    147 				       (last[-1], dim, deref, ctor));
    148     std::swap(*middle, last[-1]);
    149 
    150     if (middle < k)
    151       median_partition(middle + 1, last, k, dim, ctor);
    152     else if (middle > k)
    153       median_partition(first, middle, k, dim, ctor);
    154   }
    155 
    156   // insert given points into the tree; return created node
    157   template < class __instype, class __valuector >
    158   int insert(__instype * first, __instype * last, __valuector ctor) {
    159     if (first == last)
    160       return -1;
    161     else {
    162 
    163       int dim = dimension_of_highest_variance(first, last, ctor);
    164       __instype *median = median_partition(first, last, dim, ctor);
    165 
    166       __instype *split = median;
    167       for (; split != last && deref(ctor(*split), dim) ==
    168 	     deref(ctor(*median), dim); ++split);
    169 
    170       if (split == last) { // leaf
    171 	int nexti = -1;
    172 	for (--split; split >= first; --split) {
    173 	  int i = nodes.size();
    174 	  node & n = *nodes.insert(nodes.end(), node());
    175 	  n.dim = -1;
    176 	  n.value = ctor(*split);
    177 	  n.left = -1;
    178 	  n.right = nexti;
    179 	  nexti = i;
    180 	}
    181 
    182 	return nexti;
    183       } else { // node
    184 	int i = nodes.size();
    185 	// note that recursive insert may invalidate this ref
    186 	node & n = *nodes.insert(nodes.end(), node());
    187 
    188 	n.dim = dim;
    189 	n.boundary = deref(ctor(*median), dim);
    190 
    191 	int left = insert(first, split, ctor);
    192 	nodes[i].left = left;
    193 	int right = insert(split, last, ctor);
    194 	nodes[i].right = right;
    195 
    196 	return i;
    197       }
    198     }
    199   }
    200 
    201   // run to leaf; linear search for p;
    202   // if found, remove paths to empty leaves on unwind
    203   bool remove(int *i, const __valuetype & p) {
    204     if (*i == -1)
    205       return false;
    206     node & n = nodes[*i];
    207     bool r;
    208 
    209     if (n.dim >= 0) { // node
    210       if (deref(p, n.dim) <= n.boundary) // left
    211 	r = remove(&n.left, p);
    212       else // right
    213 	r = remove(&n.right, p);
    214 
    215       // if terminal, remove this node
    216       if (n.left == -1 && n.right == -1)
    217 	*i = -1;
    218 
    219       return r;
    220     } else { // leaf
    221       if (n.value == p) {
    222 	*i = n.right;
    223 	return true;
    224       } else
    225 	return remove(&n.right, p);
    226     }
    227   }
    228 
    229 public:
    230   struct identity_ctor {
    231     const __valuetype & operator() (const __valuetype & rhs) const {
    232       return rhs;
    233     }
    234   };
    235 
    236   // initialize an empty tree
    237   CvKDTree(__deref _deref = __deref())
    238     : deref(_deref), root_node(-1) {
    239   }
    240   // given points, initialize a balanced tree
    241   CvKDTree(__valuetype * first, __valuetype * last, int _point_dim,
    242 	   __deref _deref = __deref())
    243     : deref(_deref) {
    244     set_data(first, last, _point_dim, identity_ctor());
    245   }
    246   // given points, initialize a balanced tree
    247   template < class __instype, class __valuector >
    248   CvKDTree(__instype * first, __instype * last, int _point_dim,
    249 	   __valuector ctor, __deref _deref = __deref())
    250     : deref(_deref) {
    251     set_data(first, last, _point_dim, ctor);
    252   }
    253 
    254   void set_deref(__deref _deref) {
    255     deref = _deref;
    256   }
    257 
    258   void set_data(__valuetype * first, __valuetype * last, int _point_dim) {
    259     set_data(first, last, _point_dim, identity_ctor());
    260   }
    261   template < class __instype, class __valuector >
    262   void set_data(__instype * first, __instype * last, int _point_dim,
    263 		__valuector ctor) {
    264     point_dim = _point_dim;
    265     nodes.clear();
    266     nodes.reserve(last - first);
    267     root_node = insert(first, last, ctor);
    268   }
    269 
    270   int dims() const {
    271     return point_dim;
    272   }
    273 
    274   // remove the given point
    275   bool remove(const __valuetype & p) {
    276     return remove(&root_node, p);
    277   }
    278 
    279   void print() const {
    280     print(root_node);
    281   }
    282   void print(int i, int indent = 0) const {
    283     if (i == -1)
    284       return;
    285     for (int j = 0; j < indent; ++j)
    286       std::cout << " ";
    287     const node & n = nodes[i];
    288     if (n.dim >= 0) {
    289       std::cout << "node " << i << ", left " << nodes[i].left << ", right " <<
    290 	nodes[i].right << ", dim " << nodes[i].dim << ", boundary " <<
    291 	nodes[i].boundary << std::endl;
    292       print(n.left, indent + 3);
    293       print(n.right, indent + 3);
    294     } else
    295       std::cout << "leaf " << i << ", value = " << nodes[i].value << std::endl;
    296   }
    297 
    298   ////////////////////////////////////////////////////////////////////////////////////////
    299   // bbf search
    300 public:
    301   struct bbf_nn {		// info on found neighbors (approx k nearest)
    302     const __valuetype *p;	// nearest neighbor
    303     accum_type dist;		// distance from d to query point
    304     bbf_nn(const __valuetype & _p, accum_type _dist)
    305       : p(&_p), dist(_dist) {
    306     }
    307     bool operator<(const bbf_nn & rhs) const {
    308       return dist < rhs.dist;
    309     }
    310   };
    311   typedef std::vector < bbf_nn > bbf_nn_pqueue;
    312 private:
    313   struct bbf_node {		// info on branches not taken
    314     int node;			// corresponding node
    315     accum_type dist;		// minimum distance from bounds to query point
    316     bbf_node(int _node, accum_type _dist)
    317       : node(_node), dist(_dist) {
    318     }
    319     bool operator<(const bbf_node & rhs) const {
    320       return dist > rhs.dist;
    321     }
    322   };
    323   typedef std::vector < bbf_node > bbf_pqueue;
    324   mutable bbf_pqueue tmp_pq;
    325 
    326   // called for branches not taken, as bbf walks to leaf;
    327   // construct bbf_node given minimum distance to bounds of alternate branch
    328   void pq_alternate(int alt_n, bbf_pqueue & pq, scalar_type dist) const {
    329     if (alt_n == -1)
    330       return;
    331 
    332     // add bbf_node for alternate branch in priority queue
    333     pq.push_back(bbf_node(alt_n, dist));
    334     push_heap(pq.begin(), pq.end());
    335   }
    336 
    337   // called by bbf to walk to leaf;
    338   // takes one step down the tree towards query point d
    339   template < class __desctype >
    340   int bbf_branch(int i, const __desctype * d, bbf_pqueue & pq) const {
    341     const node & n = nodes[i];
    342     // push bbf_node with bounds of alternate branch, then branch
    343     if (d[n.dim] <= n.boundary) {	// left
    344       pq_alternate(n.right, pq, n.boundary - d[n.dim]);
    345       return n.left;
    346     } else {			// right
    347       pq_alternate(n.left, pq, d[n.dim] - n.boundary);
    348       return n.right;
    349     }
    350   }
    351 
    352   // compute euclidean distance between two points
    353   template < class __desctype >
    354   accum_type distance(const __desctype * d, const __valuetype & p) const {
    355     accum_type dist = 0;
    356     for (int j = 0; j < point_dim; ++j) {
    357       accum_type diff = accum_type(d[j]) - accum_type(deref(p, j));
    358       dist += diff * diff;
    359     } return (accum_type) sqrt(dist);
    360   }
    361 
    362   // called per candidate nearest neighbor; constructs new bbf_nn for
    363   // candidate and adds it to priority queue of all candidates; if
    364   // queue len exceeds k, drops the point furthest from query point d.
    365   template < class __desctype >
    366   void bbf_new_nn(bbf_nn_pqueue & nn_pq, int k,
    367 		  const __desctype * d, const __valuetype & p) const {
    368     bbf_nn nn(p, distance(d, p));
    369     if ((int) nn_pq.size() < k) {
    370       nn_pq.push_back(nn);
    371       push_heap(nn_pq.begin(), nn_pq.end());
    372     } else if (nn_pq[0].dist > nn.dist) {
    373       pop_heap(nn_pq.begin(), nn_pq.end());
    374       nn_pq.end()[-1] = nn;
    375       push_heap(nn_pq.begin(), nn_pq.end());
    376     }
    377     assert(nn_pq.size() < 2 || nn_pq[0].dist >= nn_pq[1].dist);
    378   }
    379 
    380 public:
    381   // finds (with high probability) the k nearest neighbors of d,
    382   // searching at most emax leaves/bins.
    383   // ret_nn_pq is an array containing the (at most) k nearest neighbors
    384   // (see bbf_nn structure def above).
    385   template < class __desctype >
    386   int find_nn_bbf(const __desctype * d,
    387 		  int k, int emax,
    388 		  bbf_nn_pqueue & ret_nn_pq) const {
    389     assert(k > 0);
    390     ret_nn_pq.clear();
    391 
    392     if (root_node == -1)
    393       return 0;
    394 
    395     // add root_node to bbf_node priority queue;
    396     // iterate while queue non-empty and emax>0
    397     tmp_pq.clear();
    398     tmp_pq.push_back(bbf_node(root_node, 0));
    399     while (tmp_pq.size() && emax > 0) {
    400 
    401       // from node nearest query point d, run to leaf
    402       pop_heap(tmp_pq.begin(), tmp_pq.end());
    403       bbf_node bbf(tmp_pq.end()[-1]);
    404       tmp_pq.erase(tmp_pq.end() - 1);
    405 
    406       int i;
    407       for (i = bbf.node;
    408 	   i != -1 && nodes[i].dim >= 0;
    409 	   i = bbf_branch(i, d, tmp_pq));
    410 
    411       if (i != -1) {
    412 
    413 	// add points in leaf/bin to ret_nn_pq
    414 	do {
    415 	  bbf_new_nn(ret_nn_pq, k, d, nodes[i].value);
    416 	} while (-1 != (i = nodes[i].right));
    417 
    418 	--emax;
    419       }
    420     }
    421 
    422     tmp_pq.clear();
    423     return ret_nn_pq.size();
    424   }
    425 
    426   ////////////////////////////////////////////////////////////////////////////////////////
    427   // orthogonal range search
    428 private:
    429   void find_ortho_range(int i, scalar_type * bounds_min,
    430 			scalar_type * bounds_max,
    431 			std::vector < __valuetype > &inbounds) const {
    432     if (i == -1)
    433       return;
    434     const node & n = nodes[i];
    435     if (n.dim >= 0) { // node
    436       if (bounds_min[n.dim] <= n.boundary)
    437 	find_ortho_range(n.left, bounds_min, bounds_max, inbounds);
    438       if (bounds_max[n.dim] > n.boundary)
    439 	find_ortho_range(n.right, bounds_min, bounds_max, inbounds);
    440     } else { // leaf
    441       do {
    442 	inbounds.push_back(nodes[i].value);
    443       } while (-1 != (i = nodes[i].right));
    444     }
    445   }
    446 public:
    447   // return all points that lie within the given bounds; inbounds is cleared
    448   int find_ortho_range(scalar_type * bounds_min,
    449 		       scalar_type * bounds_max,
    450 		       std::vector < __valuetype > &inbounds) const {
    451     inbounds.clear();
    452     find_ortho_range(root_node, bounds_min, bounds_max, inbounds);
    453     return inbounds.size();
    454   }
    455 };
    456 
    457 #endif // __cv_kdtree_h__
    458 
    459 // Local Variables:
    460 // mode:C++
    461 // End:
    462