Home | History | Annotate | Download | only in ceres
      1 // Ceres Solver - A fast non-linear least squares minimizer
      2 // Copyright 2010, 2011, 2012 Google Inc. All rights reserved.
      3 // http://code.google.com/p/ceres-solver/
      4 //
      5 // Redistribution and use in source and binary forms, with or without
      6 // modification, are permitted provided that the following conditions are met:
      7 //
      8 // * Redistributions of source code must retain the above copyright notice,
      9 //   this list of conditions and the following disclaimer.
     10 // * Redistributions in binary form must reproduce the above copyright notice,
     11 //   this list of conditions and the following disclaimer in the documentation
     12 //   and/or other materials provided with the distribution.
     13 // * Neither the name of Google Inc. nor the names of its contributors may be
     14 //   used to endorse or promote products derived from this software without
     15 //   specific prior written permission.
     16 //
     17 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
     18 // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
     19 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
     20 // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
     21 // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
     22 // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
     23 // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
     24 // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
     25 // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
     26 // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
     27 // POSSIBILITY OF SUCH DAMAGE.
     28 //
     29 // Author: David Gallup (dgallup (at) google.com)
     30 //         Sameer Agarwal (sameeragarwal (at) google.com)
     31 
     32 // This include must come before any #ifndef check on Ceres compile options.
     33 #include "ceres/internal/port.h"
     34 
     35 #ifndef CERES_NO_SUITESPARSE
     36 
     37 #include "ceres/canonical_views_clustering.h"
     38 
     39 #include "ceres/collections_port.h"
     40 #include "ceres/graph.h"
     41 #include "ceres/internal/macros.h"
     42 #include "ceres/map_util.h"
     43 #include "glog/logging.h"
     44 
     45 namespace ceres {
     46 namespace internal {
     47 
     48 typedef HashMap<int, int> IntMap;
     49 typedef HashSet<int> IntSet;
     50 
     51 class CanonicalViewsClustering {
     52  public:
     53   CanonicalViewsClustering() {}
     54 
     55   // Compute the canonical views clustering of the vertices of the
     56   // graph. centers will contain the vertices that are the identified
     57   // as the canonical views/cluster centers, and membership is a map
     58   // from vertices to cluster_ids. The i^th cluster center corresponds
     59   // to the i^th cluster. It is possible depending on the
     60   // configuration of the clustering algorithm that some of the
     61   // vertices may not be assigned to any cluster. In this case they
     62   // are assigned to a cluster with id = kInvalidClusterId.
     63   void ComputeClustering(const CanonicalViewsClusteringOptions& options,
     64                          const Graph<int>& graph,
     65                          vector<int>* centers,
     66                          IntMap* membership);
     67 
     68  private:
     69   void FindValidViews(IntSet* valid_views) const;
     70   double ComputeClusteringQualityDifference(const int candidate,
     71                                             const vector<int>& centers) const;
     72   void UpdateCanonicalViewAssignments(const int canonical_view);
     73   void ComputeClusterMembership(const vector<int>& centers,
     74                                 IntMap* membership) const;
     75 
     76   CanonicalViewsClusteringOptions options_;
     77   const Graph<int>* graph_;
     78   // Maps a view to its representative canonical view (its cluster
     79   // center).
     80   IntMap view_to_canonical_view_;
     81   // Maps a view to its similarity to its current cluster center.
     82   HashMap<int, double> view_to_canonical_view_similarity_;
     83   CERES_DISALLOW_COPY_AND_ASSIGN(CanonicalViewsClustering);
     84 };
     85 
     86 void ComputeCanonicalViewsClustering(
     87     const CanonicalViewsClusteringOptions& options,
     88     const Graph<int>& graph,
     89     vector<int>* centers,
     90     IntMap* membership) {
     91   time_t start_time = time(NULL);
     92   CanonicalViewsClustering cv;
     93   cv.ComputeClustering(options, graph, centers, membership);
     94   VLOG(2) << "Canonical views clustering time (secs): "
     95           << time(NULL) - start_time;
     96 }
     97 
     98 // Implementation of CanonicalViewsClustering
     99 void CanonicalViewsClustering::ComputeClustering(
    100     const CanonicalViewsClusteringOptions& options,
    101     const Graph<int>& graph,
    102     vector<int>* centers,
    103     IntMap* membership) {
    104   options_ = options;
    105   CHECK_NOTNULL(centers)->clear();
    106   CHECK_NOTNULL(membership)->clear();
    107   graph_ = &graph;
    108 
    109   IntSet valid_views;
    110   FindValidViews(&valid_views);
    111   while (valid_views.size() > 0) {
    112     // Find the next best canonical view.
    113     double best_difference = -std::numeric_limits<double>::max();
    114     int best_view = 0;
    115 
    116     // TODO(sameeragarwal): Make this loop multi-threaded.
    117     for (IntSet::const_iterator view = valid_views.begin();
    118          view != valid_views.end();
    119          ++view) {
    120       const double difference =
    121           ComputeClusteringQualityDifference(*view, *centers);
    122       if (difference > best_difference) {
    123         best_difference = difference;
    124         best_view = *view;
    125       }
    126     }
    127 
    128     CHECK_GT(best_difference, -std::numeric_limits<double>::max());
    129 
    130     // Add canonical view if quality improves, or if minimum is not
    131     // yet met, otherwise break.
    132     if ((best_difference <= 0) &&
    133         (centers->size() >= options_.min_views)) {
    134       break;
    135     }
    136 
    137     centers->push_back(best_view);
    138     valid_views.erase(best_view);
    139     UpdateCanonicalViewAssignments(best_view);
    140   }
    141 
    142   ComputeClusterMembership(*centers, membership);
    143 }
    144 
    145 // Return the set of vertices of the graph which have valid vertex
    146 // weights.
    147 void CanonicalViewsClustering::FindValidViews(
    148     IntSet* valid_views) const {
    149   const IntSet& views = graph_->vertices();
    150   for (IntSet::const_iterator view = views.begin();
    151        view != views.end();
    152        ++view) {
    153     if (graph_->VertexWeight(*view) != Graph<int>::InvalidWeight()) {
    154       valid_views->insert(*view);
    155     }
    156   }
    157 }
    158 
    159 // Computes the difference in the quality score if 'candidate' were
    160 // added to the set of canonical views.
    161 double CanonicalViewsClustering::ComputeClusteringQualityDifference(
    162     const int candidate,
    163     const vector<int>& centers) const {
    164   // View score.
    165   double difference =
    166       options_.view_score_weight * graph_->VertexWeight(candidate);
    167 
    168   // Compute how much the quality score changes if the candidate view
    169   // was added to the list of canonical views and its nearest
    170   // neighbors became members of its cluster.
    171   const IntSet& neighbors = graph_->Neighbors(candidate);
    172   for (IntSet::const_iterator neighbor = neighbors.begin();
    173        neighbor != neighbors.end();
    174        ++neighbor) {
    175     const double old_similarity =
    176         FindWithDefault(view_to_canonical_view_similarity_, *neighbor, 0.0);
    177     const double new_similarity = graph_->EdgeWeight(*neighbor, candidate);
    178     if (new_similarity > old_similarity) {
    179       difference += new_similarity - old_similarity;
    180     }
    181   }
    182 
    183   // Number of views penalty.
    184   difference -= options_.size_penalty_weight;
    185 
    186   // Orthogonality.
    187   for (int i = 0; i < centers.size(); ++i) {
    188     difference -= options_.similarity_penalty_weight *
    189         graph_->EdgeWeight(centers[i], candidate);
    190   }
    191 
    192   return difference;
    193 }
    194 
    195 // Reassign views if they're more similar to the new canonical view.
    196 void CanonicalViewsClustering::UpdateCanonicalViewAssignments(
    197     const int canonical_view) {
    198   const IntSet& neighbors = graph_->Neighbors(canonical_view);
    199   for (IntSet::const_iterator neighbor = neighbors.begin();
    200        neighbor != neighbors.end();
    201        ++neighbor) {
    202     const double old_similarity =
    203         FindWithDefault(view_to_canonical_view_similarity_, *neighbor, 0.0);
    204     const double new_similarity =
    205         graph_->EdgeWeight(*neighbor, canonical_view);
    206     if (new_similarity > old_similarity) {
    207       view_to_canonical_view_[*neighbor] = canonical_view;
    208       view_to_canonical_view_similarity_[*neighbor] = new_similarity;
    209     }
    210   }
    211 }
    212 
    213 // Assign a cluster id to each view.
    214 void CanonicalViewsClustering::ComputeClusterMembership(
    215     const vector<int>& centers,
    216     IntMap* membership) const {
    217   CHECK_NOTNULL(membership)->clear();
    218 
    219   // The i^th cluster has cluster id i.
    220   IntMap center_to_cluster_id;
    221   for (int i = 0; i < centers.size(); ++i) {
    222     center_to_cluster_id[centers[i]] = i;
    223   }
    224 
    225   static const int kInvalidClusterId = -1;
    226 
    227   const IntSet& views = graph_->vertices();
    228   for (IntSet::const_iterator view = views.begin();
    229        view != views.end();
    230        ++view) {
    231     IntMap::const_iterator it =
    232         view_to_canonical_view_.find(*view);
    233     int cluster_id = kInvalidClusterId;
    234     if (it != view_to_canonical_view_.end()) {
    235       cluster_id = FindOrDie(center_to_cluster_id, it->second);
    236     }
    237 
    238     InsertOrDie(membership, *view, cluster_id);
    239   }
    240 }
    241 
    242 }  // namespace internal
    243 }  // namespace ceres
    244 
    245 #endif  // CERES_NO_SUITESPARSE
    246