Home | History | Annotate | Download | only in clustering
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 package com.android.internal.ml.clustering;
     18 
     19 import static org.junit.Assert.assertEquals;
     20 import static org.junit.Assert.assertTrue;
     21 
     22 import android.annotation.SuppressLint;
     23 import android.support.test.filters.SmallTest;
     24 import android.support.test.runner.AndroidJUnit4;
     25 
     26 import org.junit.Assert;
     27 import org.junit.Before;
     28 import org.junit.Test;
     29 import org.junit.runner.RunWith;
     30 
     31 import java.util.Arrays;
     32 import java.util.List;
     33 import java.util.Random;
     34 
     35 @SmallTest
     36 @RunWith(AndroidJUnit4.class)
     37 public class KMeansTest {
     38 
     39     // Error tolerance (epsilon)
     40     private static final double EPS = 0.01;
     41 
     42     private KMeans mKMeans;
     43 
     44     @Before
     45     public void setUp() {
     46         // Setup with a random seed to have predictable results
     47         mKMeans = new KMeans(new Random(0), 30, 0);
     48     }
     49 
     50     @Test
     51     public void getCheckDataSanityTest() {
     52         try {
     53             mKMeans.checkDataSetSanity(new float[][] {
     54                     {0, 1, 2},
     55                     {1, 2, 3}
     56             });
     57         } catch (IllegalArgumentException e) {
     58             Assert.fail("Valid data didn't pass sanity check");
     59         }
     60 
     61         try {
     62             mKMeans.checkDataSetSanity(new float[][] {
     63                     null,
     64                     {1, 2, 3}
     65             });
     66             Assert.fail("Data has null items and passed");
     67         } catch (IllegalArgumentException e) {}
     68 
     69         try {
     70             mKMeans.checkDataSetSanity(new float[][] {
     71                     {0, 1, 2, 4},
     72                     {1, 2, 3}
     73             });
     74             Assert.fail("Data has invalid shape and passed");
     75         } catch (IllegalArgumentException e) {}
     76 
     77         try {
     78             mKMeans.checkDataSetSanity(null);
     79             Assert.fail("Null data should throw exception");
     80         } catch (IllegalArgumentException e) {}
     81     }
     82 
     83     @Test
     84     public void sqDistanceTest() {
     85         float a[] = {4, 10};
     86         float b[] = {5, 2};
     87         float sqDist = (float) (Math.pow(a[0] - b[0], 2) + Math.pow(a[1] - b[1], 2));
     88 
     89         assertEquals("Squared distance not valid", mKMeans.sqDistance(a, b), sqDist, EPS);
     90     }
     91 
     92     @Test
     93     public void nearestMeanTest() {
     94         KMeans.Mean meanA = new KMeans.Mean(0, 1);
     95         KMeans.Mean meanB = new KMeans.Mean(1, 1);
     96         List<KMeans.Mean> means = Arrays.asList(meanA, meanB);
     97 
     98         KMeans.Mean nearest = mKMeans.nearestMean(new float[] {1, 1}, means);
     99 
    100         assertEquals("Unexpected nearest mean for point {1, 1}", nearest, meanB);
    101     }
    102 
    103     @SuppressLint("DefaultLocale")
    104     @Test
    105     public void scoreTest() {
    106         List<KMeans.Mean> closeMeans = Arrays.asList(new KMeans.Mean(0, 0.1f, 0.1f),
    107                 new KMeans.Mean(0, 0.1f, 0.15f),
    108                 new KMeans.Mean(0.1f, 0.2f, 0.1f));
    109         List<KMeans.Mean> farMeans = Arrays.asList(new KMeans.Mean(0, 0, 0),
    110                 new KMeans.Mean(0, 0.5f, 0.5f),
    111                 new KMeans.Mean(1, 0.9f, 0.9f));
    112 
    113         double closeScore = KMeans.score(closeMeans);
    114         double farScore = KMeans.score(farMeans);
    115         assertTrue(String.format("Score of well distributed means should be greater than "
    116                 + "close means but got: %f, %f", farScore, closeScore), farScore > closeScore);
    117     }
    118 
    119     @Test
    120     public void predictTest() {
    121         float[] expectedCentroid1 = {1, 1, 1};
    122         float[] expectedCentroid2 = {0, 0, 0};
    123         float[][] X = new float[][] {
    124                 {1, 1, 1},
    125                 {1, 1, 1},
    126                 {1, 1, 1},
    127                 {0, 0, 0},
    128                 {0, 0, 0},
    129                 {0, 0, 0},
    130         };
    131 
    132         final int numClusters = 2;
    133 
    134         // Here we assume that we won't get stuck into a local optima.
    135         // It's fine because we're seeding a random, we won't ever have
    136         // unstable results but in real life we need multiple initialization
    137         // and score comparison
    138         List<KMeans.Mean> means = mKMeans.predict(numClusters, X);
    139 
    140         assertEquals("Expected number of clusters is invalid", numClusters, means.size());
    141 
    142         boolean exists1 = false, exists2 = false;
    143         for (KMeans.Mean mean : means) {
    144             if (Arrays.equals(mean.getCentroid(), expectedCentroid1)) {
    145                 exists1 = true;
    146             } else if (Arrays.equals(mean.getCentroid(), expectedCentroid2)) {
    147                 exists2 = true;
    148             } else {
    149                 throw new AssertionError("Unexpected mean: " + mean);
    150             }
    151         }
    152         assertTrue("Expected means were not predicted, got: " + means,
    153                 exists1 && exists2);
    154     }
    155 }
    156