/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.internal.ml.clustering;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import android.annotation.SuppressLint;
import android.support.test.filters.SmallTest;
import android.support.test.runner.AndroidJUnit4;

import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;

import java.util.Arrays;
import java.util.List;
import java.util.Random;

/**
 * Unit tests for {@link KMeans}: data-set validation, squared distance, nearest-mean lookup,
 * scoring of a set of means, and end-to-end prediction on a trivially separable data set.
 */
@SmallTest
@RunWith(AndroidJUnit4.class)
public class KMeansTest {

    // Error tolerance (epsilon)
    private static final double EPS = 0.01;

    private KMeans mKMeans;

    /** Creates a fresh {@link KMeans} before every test. */
    @Before
    public void setUp() {
        // Setup with a random seed to have predictable results.
        // NOTE(review): the meaning of the 30 and 0 arguments is defined by the KMeans
        // constructor, which is not visible from this file -- confirm there.
        mKMeans = new KMeans(new Random(0), 30, 0);
    }

    /**
     * Verifies that {@code checkDataSetSanity} accepts a well-formed data set and throws
     * {@link IllegalArgumentException} for a data set containing a null row, a data set whose
     * rows differ in length, and a null data set.
     */
    @Test
    public void getCheckDataSanityTest() {
        try {
            mKMeans.checkDataSetSanity(new float[][] {
                    {0, 1, 2},
                    {1, 2, 3}
            });
        } catch (IllegalArgumentException e) {
            Assert.fail("Valid data didn't pass sanity check");
        }

        assertDataSetRejected("Data has null items and passed", new float[][] {
                null,
                {1, 2, 3}
        });

        assertDataSetRejected("Data has invalid shape and passed", new float[][] {
                {0, 1, 2, 4},
                {1, 2, 3}
        });

        assertDataSetRejected("Null data should throw exception", null);
    }

    /**
     * Fails the current test with {@code failureMessage} unless {@code checkDataSetSanity}
     * throws {@link IllegalArgumentException} for {@code dataSet}.
     *
     * @param failureMessage message reported when the data set is unexpectedly accepted
     * @param dataSet the invalid data set to validate; may be null
     */
    private void assertDataSetRejected(String failureMessage, float[][] dataSet) {
        try {
            mKMeans.checkDataSetSanity(dataSet);
        } catch (IllegalArgumentException expected) {
            // Rejection is the behavior under test, so there is nothing to handle.
            return;
        }
        Assert.fail(failureMessage);
    }

    /** Verifies {@code sqDistance} against a hand-computed squared Euclidean distance. */
    @Test
    public void sqDistanceTest() {
        float[] a = {4, 10};
        float[] b = {5, 2};
        float sqDist = (float) (Math.pow(a[0] - b[0], 2) + Math.pow(a[1] - b[1], 2));

        // JUnit's contract is (message, expected, actual, delta); keeping that order makes
        // the failure output report the two values correctly.
        assertEquals("Squared distance not valid", sqDist, mKMeans.sqDistance(a, b), EPS);
    }

    /** Verifies that {@code nearestMean} picks the mean whose centroid matches the point. */
    @Test
    public void nearestMeanTest() {
        KMeans.Mean meanA = new KMeans.Mean(0, 1);
        KMeans.Mean meanB = new KMeans.Mean(1, 1);
        List<KMeans.Mean> means = Arrays.asList(meanA, meanB);

        KMeans.Mean nearest = mKMeans.nearestMean(new float[] {1, 1}, means);

        // Expected value first, actual value second (JUnit argument order).
        assertEquals("Unexpected nearest mean for point {1, 1}", meanB, nearest);
    }

    /**
     * Verifies that {@code KMeans.score} is strictly greater for means that are spread apart
     * than for means that are close together.
     */
    @SuppressLint("DefaultLocale")
    @Test
    public void scoreTest() {
        List<KMeans.Mean> closeMeans = Arrays.asList(new KMeans.Mean(0, 0.1f, 0.1f),
                new KMeans.Mean(0, 0.1f, 0.15f),
                new KMeans.Mean(0.1f, 0.2f, 0.1f));
        List<KMeans.Mean> farMeans = Arrays.asList(new KMeans.Mean(0, 0, 0),
                new KMeans.Mean(0, 0.5f, 0.5f),
                new KMeans.Mean(1, 0.9f, 0.9f));

        double closeScore = KMeans.score(closeMeans);
        double farScore = KMeans.score(farMeans);
        assertTrue(String.format("Score of well distributed means should be greater than "
                + "close means but got: %f, %f", farScore, closeScore), farScore > closeScore);
    }

    /**
     * Verifies that {@code predict} with two clusters, run on three copies each of
     * {@code {1, 1, 1}} and {@code {0, 0, 0}}, returns exactly two means whose centroids are
     * those two points.
     */
    @Test
    public void predictTest() {
        float[] expectedCentroid1 = {1, 1, 1};
        float[] expectedCentroid2 = {0, 0, 0};
        float[][] X = new float[][] {
                {1, 1, 1},
                {1, 1, 1},
                {1, 1, 1},
                {0, 0, 0},
                {0, 0, 0},
                {0, 0, 0},
        };

        final int numClusters = 2;

        // Here we assume that we won't get stuck in a local optimum.
        // That is fine because the Random is seeded, so results are never unstable;
        // in real life we need multiple initializations and score comparison.
        List<KMeans.Mean> means = mKMeans.predict(numClusters, X);

        assertEquals("Expected number of clusters is invalid", numClusters, means.size());

        boolean exists1 = false, exists2 = false;
        for (KMeans.Mean mean : means) {
            if (Arrays.equals(mean.getCentroid(), expectedCentroid1)) {
                exists1 = true;
            } else if (Arrays.equals(mean.getCentroid(), expectedCentroid2)) {
                exists2 = true;
            } else {
                // Any centroid other than the two expected ones is an immediate failure.
                throw new AssertionError("Unexpected mean: " + mean);
            }
        }
        assertTrue("Expected means were not predicted, got: " + means,
                exists1 && exists2);
    }
}