Home | History | Annotate | Download | only in clustering
      1 /*
      2  * Licensed to the Apache Software Foundation (ASF) under one or more
      3  * contributor license agreements.  See the NOTICE file distributed with
      4  * this work for additional information regarding copyright ownership.
      5  * The ASF licenses this file to You under the Apache License, Version 2.0
      6  * (the "License"); you may not use this file except in compliance with
      7  * the License.  You may obtain a copy of the License at
      8  *
      9  *      http://www.apache.org/licenses/LICENSE-2.0
     10  *
     11  * Unless required by applicable law or agreed to in writing, software
     12  * distributed under the License is distributed on an "AS IS" BASIS,
     13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14  * See the License for the specific language governing permissions and
     15  * limitations under the License.
     16  */
     17 
     18 package org.apache.commons.math.stat.clustering;
     19 
     20 import java.util.ArrayList;
     21 import java.util.Collection;
     22 import java.util.List;
     23 import java.util.Random;
     24 
     25 import org.apache.commons.math.exception.ConvergenceException;
     26 import org.apache.commons.math.exception.util.LocalizedFormats;
     27 import org.apache.commons.math.stat.descriptive.moment.Variance;
     28 
     29 /**
     30  * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm.
     31  * @param <T> type of the points to cluster
     32  * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
     33  * @version $Revision: 1054333 $ $Date: 2011-01-02 01:34:58 +0100 (dim. 02 janv. 2011) $
     34  * @since 2.0
     35  */
     36 public class KMeansPlusPlusClusterer<T extends Clusterable<T>> {
     37 
     38     /** Strategies to use for replacing an empty cluster. */
     39     public static enum EmptyClusterStrategy {
     40 
     41         /** Split the cluster with largest distance variance. */
     42         LARGEST_VARIANCE,
     43 
     44         /** Split the cluster with largest number of points. */
     45         LARGEST_POINTS_NUMBER,
     46 
     47         /** Create a cluster around the point farthest from its centroid. */
     48         FARTHEST_POINT,
     49 
     50         /** Generate an error. */
     51         ERROR
     52 
     53     }
     54 
     55     /** Random generator for choosing initial centers. */
     56     private final Random random;
     57 
     58     /** Selected strategy for empty clusters. */
     59     private final EmptyClusterStrategy emptyStrategy;
     60 
     61     /** Build a clusterer.
     62      * <p>
     63      * The default strategy for handling empty clusters that may appear during
     64      * algorithm iterations is to split the cluster with largest distance variance.
     65      * </p>
     66      * @param random random generator to use for choosing initial centers
     67      */
     68     public KMeansPlusPlusClusterer(final Random random) {
     69         this(random, EmptyClusterStrategy.LARGEST_VARIANCE);
     70     }
     71 
     72     /** Build a clusterer.
     73      * @param random random generator to use for choosing initial centers
     74      * @param emptyStrategy strategy to use for handling empty clusters that
     75      * may appear during algorithm iterations
     76      * @since 2.2
     77      */
     78     public KMeansPlusPlusClusterer(final Random random, final EmptyClusterStrategy emptyStrategy) {
     79         this.random        = random;
     80         this.emptyStrategy = emptyStrategy;
     81     }
     82 
     83     /**
     84      * Runs the K-means++ clustering algorithm.
     85      *
     86      * @param points the points to cluster
     87      * @param k the number of clusters to split the data into
     88      * @param maxIterations the maximum number of iterations to run the algorithm
     89      *     for.  If negative, no maximum will be used
     90      * @return a list of clusters containing the points
     91      */
     92     public List<Cluster<T>> cluster(final Collection<T> points,
     93                                     final int k, final int maxIterations) {
     94         // create the initial clusters
     95         List<Cluster<T>> clusters = chooseInitialCenters(points, k, random);
     96         assignPointsToClusters(clusters, points);
     97 
     98         // iterate through updating the centers until we're done
     99         final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
    100         for (int count = 0; count < max; count++) {
    101             boolean clusteringChanged = false;
    102             List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>();
    103             for (final Cluster<T> cluster : clusters) {
    104                 final T newCenter;
    105                 if (cluster.getPoints().isEmpty()) {
    106                     switch (emptyStrategy) {
    107                         case LARGEST_VARIANCE :
    108                             newCenter = getPointFromLargestVarianceCluster(clusters);
    109                             break;
    110                         case LARGEST_POINTS_NUMBER :
    111                             newCenter = getPointFromLargestNumberCluster(clusters);
    112                             break;
    113                         case FARTHEST_POINT :
    114                             newCenter = getFarthestPoint(clusters);
    115                             break;
    116                         default :
    117                             throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
    118                     }
    119                     clusteringChanged = true;
    120                 } else {
    121                     newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
    122                     if (!newCenter.equals(cluster.getCenter())) {
    123                         clusteringChanged = true;
    124                     }
    125                 }
    126                 newClusters.add(new Cluster<T>(newCenter));
    127             }
    128             if (!clusteringChanged) {
    129                 return clusters;
    130             }
    131             assignPointsToClusters(newClusters, points);
    132             clusters = newClusters;
    133         }
    134         return clusters;
    135     }
    136 
    137     /**
    138      * Adds the given points to the closest {@link Cluster}.
    139      *
    140      * @param <T> type of the points to cluster
    141      * @param clusters the {@link Cluster}s to add the points to
    142      * @param points the points to add to the given {@link Cluster}s
    143      */
    144     private static <T extends Clusterable<T>> void
    145         assignPointsToClusters(final Collection<Cluster<T>> clusters, final Collection<T> points) {
    146         for (final T p : points) {
    147             Cluster<T> cluster = getNearestCluster(clusters, p);
    148             cluster.addPoint(p);
    149         }
    150     }
    151 
    152     /**
    153      * Use K-means++ to choose the initial centers.
    154      *
    155      * @param <T> type of the points to cluster
    156      * @param points the points to choose the initial centers from
    157      * @param k the number of centers to choose
    158      * @param random random generator to use
    159      * @return the initial centers
    160      */
    161     private static <T extends Clusterable<T>> List<Cluster<T>>
    162         chooseInitialCenters(final Collection<T> points, final int k, final Random random) {
    163 
    164         final List<T> pointSet = new ArrayList<T>(points);
    165         final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>();
    166 
    167         // Choose one center uniformly at random from among the data points.
    168         final T firstPoint = pointSet.remove(random.nextInt(pointSet.size()));
    169         resultSet.add(new Cluster<T>(firstPoint));
    170 
    171         final double[] dx2 = new double[pointSet.size()];
    172         while (resultSet.size() < k) {
    173             // For each data point x, compute D(x), the distance between x and
    174             // the nearest center that has already been chosen.
    175             int sum = 0;
    176             for (int i = 0; i < pointSet.size(); i++) {
    177                 final T p = pointSet.get(i);
    178                 final Cluster<T> nearest = getNearestCluster(resultSet, p);
    179                 final double d = p.distanceFrom(nearest.getCenter());
    180                 sum += d * d;
    181                 dx2[i] = sum;
    182             }
    183 
    184             // Add one new data point as a center. Each point x is chosen with
    185             // probability proportional to D(x)2
    186             final double r = random.nextDouble() * sum;
    187             for (int i = 0 ; i < dx2.length; i++) {
    188                 if (dx2[i] >= r) {
    189                     final T p = pointSet.remove(i);
    190                     resultSet.add(new Cluster<T>(p));
    191                     break;
    192                 }
    193             }
    194         }
    195 
    196         return resultSet;
    197 
    198     }
    199 
    200     /**
    201      * Get a random point from the {@link Cluster} with the largest distance variance.
    202      *
    203      * @param clusters the {@link Cluster}s to search
    204      * @return a random point from the selected cluster
    205      */
    206     private T getPointFromLargestVarianceCluster(final Collection<Cluster<T>> clusters) {
    207 
    208         double maxVariance = Double.NEGATIVE_INFINITY;
    209         Cluster<T> selected = null;
    210         for (final Cluster<T> cluster : clusters) {
    211             if (!cluster.getPoints().isEmpty()) {
    212 
    213                 // compute the distance variance of the current cluster
    214                 final T center = cluster.getCenter();
    215                 final Variance stat = new Variance();
    216                 for (final T point : cluster.getPoints()) {
    217                     stat.increment(point.distanceFrom(center));
    218                 }
    219                 final double variance = stat.getResult();
    220 
    221                 // select the cluster with the largest variance
    222                 if (variance > maxVariance) {
    223                     maxVariance = variance;
    224                     selected = cluster;
    225                 }
    226 
    227             }
    228         }
    229 
    230         // did we find at least one non-empty cluster ?
    231         if (selected == null) {
    232             throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
    233         }
    234 
    235         // extract a random point from the cluster
    236         final List<T> selectedPoints = selected.getPoints();
    237         return selectedPoints.remove(random.nextInt(selectedPoints.size()));
    238 
    239     }
    240 
    241     /**
    242      * Get a random point from the {@link Cluster} with the largest number of points
    243      *
    244      * @param clusters the {@link Cluster}s to search
    245      * @return a random point from the selected cluster
    246      */
    247     private T getPointFromLargestNumberCluster(final Collection<Cluster<T>> clusters) {
    248 
    249         int maxNumber = 0;
    250         Cluster<T> selected = null;
    251         for (final Cluster<T> cluster : clusters) {
    252 
    253             // get the number of points of the current cluster
    254             final int number = cluster.getPoints().size();
    255 
    256             // select the cluster with the largest number of points
    257             if (number > maxNumber) {
    258                 maxNumber = number;
    259                 selected = cluster;
    260             }
    261 
    262         }
    263 
    264         // did we find at least one non-empty cluster ?
    265         if (selected == null) {
    266             throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
    267         }
    268 
    269         // extract a random point from the cluster
    270         final List<T> selectedPoints = selected.getPoints();
    271         return selectedPoints.remove(random.nextInt(selectedPoints.size()));
    272 
    273     }
    274 
    275     /**
    276      * Get the point farthest to its cluster center
    277      *
    278      * @param clusters the {@link Cluster}s to search
    279      * @return point farthest to its cluster center
    280      */
    281     private T getFarthestPoint(final Collection<Cluster<T>> clusters) {
    282 
    283         double maxDistance = Double.NEGATIVE_INFINITY;
    284         Cluster<T> selectedCluster = null;
    285         int selectedPoint = -1;
    286         for (final Cluster<T> cluster : clusters) {
    287 
    288             // get the farthest point
    289             final T center = cluster.getCenter();
    290             final List<T> points = cluster.getPoints();
    291             for (int i = 0; i < points.size(); ++i) {
    292                 final double distance = points.get(i).distanceFrom(center);
    293                 if (distance > maxDistance) {
    294                     maxDistance     = distance;
    295                     selectedCluster = cluster;
    296                     selectedPoint   = i;
    297                 }
    298             }
    299 
    300         }
    301 
    302         // did we find at least one non-empty cluster ?
    303         if (selectedCluster == null) {
    304             throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
    305         }
    306 
    307         return selectedCluster.getPoints().remove(selectedPoint);
    308 
    309     }
    310 
    311     /**
    312      * Returns the nearest {@link Cluster} to the given point
    313      *
    314      * @param <T> type of the points to cluster
    315      * @param clusters the {@link Cluster}s to search
    316      * @param point the point to find the nearest {@link Cluster} for
    317      * @return the nearest {@link Cluster} to the given point
    318      */
    319     private static <T extends Clusterable<T>> Cluster<T>
    320         getNearestCluster(final Collection<Cluster<T>> clusters, final T point) {
    321         double minDistance = Double.MAX_VALUE;
    322         Cluster<T> minCluster = null;
    323         for (final Cluster<T> c : clusters) {
    324             final double distance = point.distanceFrom(c.getCenter());
    325             if (distance < minDistance) {
    326                 minDistance = distance;
    327                 minCluster = c;
    328             }
    329         }
    330         return minCluster;
    331     }
    332 
    333 }
    334