Home | History | Annotate | Download | only in clustering
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License
     15  */
     16 
     17 package com.android.internal.ml.clustering;
     18 
     19 import android.annotation.NonNull;
     20 import android.util.Log;
     21 
     22 import com.android.internal.annotations.VisibleForTesting;
     23 
     24 import java.util.ArrayList;
     25 import java.util.Arrays;
     26 import java.util.List;
     27 import java.util.Random;
     28 
     29 /**
     30  * Simple K-Means implementation
     31  */
     32 public class KMeans {
     33 
     34     private static final boolean DEBUG = false;
     35     private static final String TAG = "KMeans";
     36     private final Random mRandomState;
     37     private final int mMaxIterations;
     38     private float mSqConvergenceEpsilon;
     39 
     40     public KMeans() {
     41         this(new Random());
     42     }
     43 
     44     public KMeans(Random random) {
     45         this(random, 30 /* maxIterations */, 0.005f /* convergenceEpsilon */);
     46     }
     47     public KMeans(Random random, int maxIterations, float convergenceEpsilon) {
     48         mRandomState = random;
     49         mMaxIterations = maxIterations;
     50         mSqConvergenceEpsilon = convergenceEpsilon * convergenceEpsilon;
     51     }
     52 
     53     /**
     54      * Runs k-means on the input data (X) trying to find k means.
     55      *
     56      * K-Means is known for getting stuck into local optima, so you might
     57      * want to run it multiple time and argmax on {@link KMeans#score(List)}
     58      *
     59      * @param k The number of points to return.
     60      * @param inputData Input data.
     61      * @return An array of k Means, each representing a centroid and data points that belong to it.
     62      */
     63     public List<Mean> predict(final int k, final float[][] inputData) {
     64         checkDataSetSanity(inputData);
     65         int dimension = inputData[0].length;
     66 
     67         final ArrayList<Mean> means = new ArrayList<>();
     68         for (int i = 0; i < k; i++) {
     69             Mean m = new Mean(dimension);
     70             for (int j = 0; j < dimension; j++) {
     71                 m.mCentroid[j] = mRandomState.nextFloat();
     72             }
     73             means.add(m);
     74         }
     75 
     76         // Iterate until we converge or run out of iterations
     77         boolean converged = false;
     78         for (int i = 0; i < mMaxIterations; i++) {
     79             converged = step(means, inputData);
     80             if (converged) {
     81                 if (DEBUG) Log.d(TAG, "Converged at iteration: " + i);
     82                 break;
     83             }
     84         }
     85         if (!converged && DEBUG) Log.d(TAG, "Did not converge");
     86 
     87         return means;
     88     }
     89 
     90     /**
     91      * Score calculates the inertia between means.
     92      * This can be considered as an E step of an EM algorithm.
     93      *
     94      * @param means Means to use when calculating score.
     95      * @return The score
     96      */
     97     public static double score(@NonNull List<Mean> means) {
     98         double score = 0;
     99         final int meansSize = means.size();
    100         for (int i = 0; i < meansSize; i++) {
    101             Mean mean = means.get(i);
    102             for (int j = 0; j < meansSize; j++) {
    103                 Mean compareTo = means.get(j);
    104                 if (mean == compareTo) {
    105                     continue;
    106                 }
    107                 double distance = Math.sqrt(sqDistance(mean.mCentroid, compareTo.mCentroid));
    108                 score += distance;
    109             }
    110         }
    111         return score;
    112     }
    113 
    114     @VisibleForTesting
    115     public void checkDataSetSanity(float[][] inputData) {
    116         if (inputData == null) {
    117             throw new IllegalArgumentException("Data set is null.");
    118         } else if (inputData.length == 0) {
    119             throw new IllegalArgumentException("Data set is empty.");
    120         } else if (inputData[0] == null) {
    121             throw new IllegalArgumentException("Bad data set format.");
    122         }
    123 
    124         final int dimension = inputData[0].length;
    125         final int length = inputData.length;
    126         for (int i = 1; i < length; i++) {
    127             if (inputData[i] == null || inputData[i].length != dimension) {
    128                 throw new IllegalArgumentException("Bad data set format.");
    129             }
    130         }
    131     }
    132 
    133     /**
    134      * K-Means iteration.
    135      *
    136      * @param means Current means
    137      * @param inputData Input data
    138      * @return True if data set converged
    139      */
    140     private boolean step(final ArrayList<Mean> means, final float[][] inputData) {
    141 
    142         // Clean up the previous state because we need to compute
    143         // which point belongs to each mean again.
    144         for (int i = means.size() - 1; i >= 0; i--) {
    145             final Mean mean = means.get(i);
    146             mean.mClosestItems.clear();
    147         }
    148         for (int i = inputData.length - 1; i >= 0; i--) {
    149             final float[] current = inputData[i];
    150             final Mean nearest = nearestMean(current, means);
    151             nearest.mClosestItems.add(current);
    152         }
    153 
    154         boolean converged = true;
    155         // Move each mean towards the nearest data set points
    156         for (int i = means.size() - 1; i >= 0; i--) {
    157             final Mean mean = means.get(i);
    158             if (mean.mClosestItems.size() == 0) {
    159                 continue;
    160             }
    161 
    162             // Compute the new mean centroid:
    163             //   1. Sum all all points
    164             //   2. Average them
    165             final float[] oldCentroid = mean.mCentroid;
    166             mean.mCentroid = new float[oldCentroid.length];
    167             for (int j = 0; j < mean.mClosestItems.size(); j++) {
    168                 // Update each centroid component
    169                 for (int p = 0; p < mean.mCentroid.length; p++) {
    170                     mean.mCentroid[p] += mean.mClosestItems.get(j)[p];
    171                 }
    172             }
    173             for (int j = 0; j < mean.mCentroid.length; j++) {
    174                 mean.mCentroid[j] /= mean.mClosestItems.size();
    175             }
    176 
    177             // We converged if the centroid didn't move for any of the means.
    178             if (sqDistance(oldCentroid, mean.mCentroid) > mSqConvergenceEpsilon) {
    179                 converged = false;
    180             }
    181         }
    182         return converged;
    183     }
    184 
    185     @VisibleForTesting
    186     public static Mean nearestMean(float[] point, List<Mean> means) {
    187         Mean nearest = null;
    188         float nearestDistance = Float.MAX_VALUE;
    189 
    190         final int meanCount = means.size();
    191         for (int i = 0; i < meanCount; i++) {
    192             Mean next = means.get(i);
    193             // We don't need the sqrt when comparing distances in euclidean space
    194             // because they exist on both sides of the equation and cancel each other out.
    195             float nextDistance = sqDistance(point, next.mCentroid);
    196             if (nextDistance < nearestDistance) {
    197                 nearest = next;
    198                 nearestDistance = nextDistance;
    199             }
    200         }
    201         return nearest;
    202     }
    203 
    204     @VisibleForTesting
    205     public static float sqDistance(float[] a, float[] b) {
    206         float dist = 0;
    207         final int length = a.length;
    208         for (int i = 0; i < length; i++) {
    209             dist += (a[i] - b[i]) * (a[i] - b[i]);
    210         }
    211         return dist;
    212     }
    213 
    214     /**
    215      * Definition of a mean, contains a centroid and points on its cluster.
    216      */
    217     public static class Mean {
    218         float[] mCentroid;
    219         final ArrayList<float[]> mClosestItems = new ArrayList<>();
    220 
    221         public Mean(int dimension) {
    222             mCentroid = new float[dimension];
    223         }
    224 
    225         public Mean(float ...centroid) {
    226             mCentroid = centroid;
    227         }
    228 
    229         public float[] getCentroid() {
    230             return mCentroid;
    231         }
    232 
    233         public List<float[]> getItems() {
    234             return mClosestItems;
    235         }
    236 
    237         @Override
    238         public String toString() {
    239             return "Mean(centroid: " + Arrays.toString(mCentroid) + ", size: "
    240                     + mClosestItems.size() + ")";
    241         }
    242     }
    243 }
    244