1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 18 package org.apache.commons.math.stat.clustering; 19 20 import java.util.ArrayList; 21 import java.util.Collection; 22 import java.util.List; 23 import java.util.Random; 24 25 import org.apache.commons.math.exception.ConvergenceException; 26 import org.apache.commons.math.exception.util.LocalizedFormats; 27 import org.apache.commons.math.stat.descriptive.moment.Variance; 28 29 /** 30 * Clustering algorithm based on David Arthur and Sergei Vassilvitski k-means++ algorithm. 31 * @param <T> type of the points to cluster 32 * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a> 33 * @version $Revision: 1054333 $ $Date: 2011-01-02 01:34:58 +0100 (dim. 02 janv. 2011) $ 34 * @since 2.0 35 */ 36 public class KMeansPlusPlusClusterer<T extends Clusterable<T>> { 37 38 /** Strategies to use for replacing an empty cluster. */ 39 public static enum EmptyClusterStrategy { 40 41 /** Split the cluster with largest distance variance. */ 42 LARGEST_VARIANCE, 43 44 /** Split the cluster with largest number of points. */ 45 LARGEST_POINTS_NUMBER, 46 47 /** Create a cluster around the point farthest from its centroid. */ 48 FARTHEST_POINT, 49 50 /** Generate an error. */ 51 ERROR 52 53 } 54 55 /** Random generator for choosing initial centers. */ 56 private final Random random; 57 58 /** Selected strategy for empty clusters. */ 59 private final EmptyClusterStrategy emptyStrategy; 60 61 /** Build a clusterer. 62 * <p> 63 * The default strategy for handling empty clusters that may appear during 64 * algorithm iterations is to split the cluster with largest distance variance. 65 * </p> 66 * @param random random generator to use for choosing initial centers 67 */ 68 public KMeansPlusPlusClusterer(final Random random) { 69 this(random, EmptyClusterStrategy.LARGEST_VARIANCE); 70 } 71 72 /** Build a clusterer. 73 * @param random random generator to use for choosing initial centers 74 * @param emptyStrategy strategy to use for handling empty clusters that 75 * may appear during algorithm iterations 76 * @since 2.2 77 */ 78 public KMeansPlusPlusClusterer(final Random random, final EmptyClusterStrategy emptyStrategy) { 79 this.random = random; 80 this.emptyStrategy = emptyStrategy; 81 } 82 83 /** 84 * Runs the K-means++ clustering algorithm. 85 * 86 * @param points the points to cluster 87 * @param k the number of clusters to split the data into 88 * @param maxIterations the maximum number of iterations to run the algorithm 89 * for. If negative, no maximum will be used 90 * @return a list of clusters containing the points 91 */ 92 public List<Cluster<T>> cluster(final Collection<T> points, 93 final int k, final int maxIterations) { 94 // create the initial clusters 95 List<Cluster<T>> clusters = chooseInitialCenters(points, k, random); 96 assignPointsToClusters(clusters, points); 97 98 // iterate through updating the centers until we're done 99 final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations; 100 for (int count = 0; count < max; count++) { 101 boolean clusteringChanged = false; 102 List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>(); 103 for (final Cluster<T> cluster : clusters) { 104 final T newCenter; 105 if (cluster.getPoints().isEmpty()) { 106 switch (emptyStrategy) { 107 case LARGEST_VARIANCE : 108 newCenter = getPointFromLargestVarianceCluster(clusters); 109 break; 110 case LARGEST_POINTS_NUMBER : 111 newCenter = getPointFromLargestNumberCluster(clusters); 112 break; 113 case FARTHEST_POINT : 114 newCenter = getFarthestPoint(clusters); 115 break; 116 default : 117 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); 118 } 119 clusteringChanged = true; 120 } else { 121 newCenter = cluster.getCenter().centroidOf(cluster.getPoints()); 122 if (!newCenter.equals(cluster.getCenter())) { 123 clusteringChanged = true; 124 } 125 } 126 newClusters.add(new Cluster<T>(newCenter)); 127 } 128 if (!clusteringChanged) { 129 return clusters; 130 } 131 assignPointsToClusters(newClusters, points); 132 clusters = newClusters; 133 } 134 return clusters; 135 } 136 137 /** 138 * Adds the given points to the closest {@link Cluster}. 139 * 140 * @param <T> type of the points to cluster 141 * @param clusters the {@link Cluster}s to add the points to 142 * @param points the points to add to the given {@link Cluster}s 143 */ 144 private static <T extends Clusterable<T>> void 145 assignPointsToClusters(final Collection<Cluster<T>> clusters, final Collection<T> points) { 146 for (final T p : points) { 147 Cluster<T> cluster = getNearestCluster(clusters, p); 148 cluster.addPoint(p); 149 } 150 } 151 152 /** 153 * Use K-means++ to choose the initial centers. 154 * 155 * @param <T> type of the points to cluster 156 * @param points the points to choose the initial centers from 157 * @param k the number of centers to choose 158 * @param random random generator to use 159 * @return the initial centers 160 */ 161 private static <T extends Clusterable<T>> List<Cluster<T>> 162 chooseInitialCenters(final Collection<T> points, final int k, final Random random) { 163 164 final List<T> pointSet = new ArrayList<T>(points); 165 final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>(); 166 167 // Choose one center uniformly at random from among the data points. 168 final T firstPoint = pointSet.remove(random.nextInt(pointSet.size())); 169 resultSet.add(new Cluster<T>(firstPoint)); 170 171 final double[] dx2 = new double[pointSet.size()]; 172 while (resultSet.size() < k) { 173 // For each data point x, compute D(x), the distance between x and 174 // the nearest center that has already been chosen. 175 int sum = 0; 176 for (int i = 0; i < pointSet.size(); i++) { 177 final T p = pointSet.get(i); 178 final Cluster<T> nearest = getNearestCluster(resultSet, p); 179 final double d = p.distanceFrom(nearest.getCenter()); 180 sum += d * d; 181 dx2[i] = sum; 182 } 183 184 // Add one new data point as a center. Each point x is chosen with 185 // probability proportional to D(x)2 186 final double r = random.nextDouble() * sum; 187 for (int i = 0 ; i < dx2.length; i++) { 188 if (dx2[i] >= r) { 189 final T p = pointSet.remove(i); 190 resultSet.add(new Cluster<T>(p)); 191 break; 192 } 193 } 194 } 195 196 return resultSet; 197 198 } 199 200 /** 201 * Get a random point from the {@link Cluster} with the largest distance variance. 202 * 203 * @param clusters the {@link Cluster}s to search 204 * @return a random point from the selected cluster 205 */ 206 private T getPointFromLargestVarianceCluster(final Collection<Cluster<T>> clusters) { 207 208 double maxVariance = Double.NEGATIVE_INFINITY; 209 Cluster<T> selected = null; 210 for (final Cluster<T> cluster : clusters) { 211 if (!cluster.getPoints().isEmpty()) { 212 213 // compute the distance variance of the current cluster 214 final T center = cluster.getCenter(); 215 final Variance stat = new Variance(); 216 for (final T point : cluster.getPoints()) { 217 stat.increment(point.distanceFrom(center)); 218 } 219 final double variance = stat.getResult(); 220 221 // select the cluster with the largest variance 222 if (variance > maxVariance) { 223 maxVariance = variance; 224 selected = cluster; 225 } 226 227 } 228 } 229 230 // did we find at least one non-empty cluster ? 231 if (selected == null) { 232 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); 233 } 234 235 // extract a random point from the cluster 236 final List<T> selectedPoints = selected.getPoints(); 237 return selectedPoints.remove(random.nextInt(selectedPoints.size())); 238 239 } 240 241 /** 242 * Get a random point from the {@link Cluster} with the largest number of points 243 * 244 * @param clusters the {@link Cluster}s to search 245 * @return a random point from the selected cluster 246 */ 247 private T getPointFromLargestNumberCluster(final Collection<Cluster<T>> clusters) { 248 249 int maxNumber = 0; 250 Cluster<T> selected = null; 251 for (final Cluster<T> cluster : clusters) { 252 253 // get the number of points of the current cluster 254 final int number = cluster.getPoints().size(); 255 256 // select the cluster with the largest number of points 257 if (number > maxNumber) { 258 maxNumber = number; 259 selected = cluster; 260 } 261 262 } 263 264 // did we find at least one non-empty cluster ? 265 if (selected == null) { 266 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); 267 } 268 269 // extract a random point from the cluster 270 final List<T> selectedPoints = selected.getPoints(); 271 return selectedPoints.remove(random.nextInt(selectedPoints.size())); 272 273 } 274 275 /** 276 * Get the point farthest to its cluster center 277 * 278 * @param clusters the {@link Cluster}s to search 279 * @return point farthest to its cluster center 280 */ 281 private T getFarthestPoint(final Collection<Cluster<T>> clusters) { 282 283 double maxDistance = Double.NEGATIVE_INFINITY; 284 Cluster<T> selectedCluster = null; 285 int selectedPoint = -1; 286 for (final Cluster<T> cluster : clusters) { 287 288 // get the farthest point 289 final T center = cluster.getCenter(); 290 final List<T> points = cluster.getPoints(); 291 for (int i = 0; i < points.size(); ++i) { 292 final double distance = points.get(i).distanceFrom(center); 293 if (distance > maxDistance) { 294 maxDistance = distance; 295 selectedCluster = cluster; 296 selectedPoint = i; 297 } 298 } 299 300 } 301 302 // did we find at least one non-empty cluster ? 303 if (selectedCluster == null) { 304 throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS); 305 } 306 307 return selectedCluster.getPoints().remove(selectedPoint); 308 309 } 310 311 /** 312 * Returns the nearest {@link Cluster} to the given point 313 * 314 * @param <T> type of the points to cluster 315 * @param clusters the {@link Cluster}s to search 316 * @param point the point to find the nearest {@link Cluster} for 317 * @return the nearest {@link Cluster} to the given point 318 */ 319 private static <T extends Clusterable<T>> Cluster<T> 320 getNearestCluster(final Collection<Cluster<T>> clusters, final T point) { 321 double minDistance = Double.MAX_VALUE; 322 Cluster<T> minCluster = null; 323 for (final Cluster<T> c : clusters) { 324 final double distance = point.distanceFrom(c.getCenter()); 325 if (distance < minDistance) { 326 minDistance = distance; 327 minCluster = c; 328 } 329 } 330 return minCluster; 331 } 332 333 } 334