Home | History | Annotate | Download | only in palette
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License
     15  */
     16 
     17 package com.android.internal.graphics.palette;
     18 
     19 import android.util.Log;
     20 
     21 import com.android.internal.graphics.ColorUtils;
     22 import com.android.internal.ml.clustering.KMeans;
     23 
     24 import java.util.ArrayList;
     25 import java.util.List;
     26 import java.util.Random;
     27 
     28 /**
     29  * A quantizer that uses k-means
     30  */
     31 public class VariationalKMeansQuantizer implements Quantizer {
     32 
     33     private static final String TAG = "KMeansQuantizer";
     34     private static final boolean DEBUG = false;
     35 
     36     /**
     37      * Clusters closer than this value will me merged.
     38      */
     39     private final float mMinClusterSqDistance;
     40 
     41     /**
     42      * K-means can get stuck in local optima, this can be avoided by
     43      * repeating it and getting the "best" execution.
     44      */
     45     private final int mInitializations;
     46 
     47     /**
     48      * Initialize KMeans with a fixed random state to have
     49      * consistent results across multiple runs.
     50      */
     51     private final KMeans mKMeans = new KMeans(new Random(0), 30, 0);
     52 
     53     private List<Palette.Swatch> mQuantizedColors;
     54 
     55     public VariationalKMeansQuantizer() {
     56         this(0.25f /* cluster distance */);
     57     }
     58 
     59     public VariationalKMeansQuantizer(float minClusterDistance) {
     60         this(minClusterDistance, 1 /* initializations */);
     61     }
     62 
     63     public VariationalKMeansQuantizer(float minClusterDistance, int initializations) {
     64         mMinClusterSqDistance = minClusterDistance * minClusterDistance;
     65         mInitializations = initializations;
     66     }
     67 
     68     /**
     69      * K-Means quantizer.
     70      *
     71      * @param pixels Pixels to quantize.
     72      * @param maxColors Maximum number of clusters to extract.
     73      * @param filters Colors that should be ignored
     74      */
     75     @Override
     76     public void quantize(int[] pixels, int maxColors, Palette.Filter[] filters) {
     77         // Start by converting all colors to HSL.
     78         // HLS is way more meaningful for clustering than RGB.
     79         final float[] hsl = {0, 0, 0};
     80         final float[][] hslPixels = new float[pixels.length][3];
     81         for (int i = 0; i < pixels.length; i++) {
     82             ColorUtils.colorToHSL(pixels[i], hsl);
     83             // Normalize hue so all values go from 0 to 1.
     84             hslPixels[i][0] = hsl[0] / 360f;
     85             hslPixels[i][1] = hsl[1];
     86             hslPixels[i][2] = hsl[2];
     87         }
     88 
     89         final List<KMeans.Mean> optimalMeans = getOptimalKMeans(maxColors, hslPixels);
     90 
     91         // Ideally we should run k-means again to merge clusters but it would be too expensive,
     92         // instead we just merge all clusters that are closer than a threshold.
     93         for (int i = 0; i < optimalMeans.size(); i++) {
     94             KMeans.Mean current = optimalMeans.get(i);
     95             float[] currentCentroid = current.getCentroid();
     96             for (int j = i + 1; j < optimalMeans.size(); j++) {
     97                 KMeans.Mean compareTo = optimalMeans.get(j);
     98                 float[] compareToCentroid = compareTo.getCentroid();
     99                 float sqDistance = KMeans.sqDistance(currentCentroid, compareToCentroid);
    100                 // Merge them
    101                 if (sqDistance < mMinClusterSqDistance) {
    102                     optimalMeans.remove(compareTo);
    103                     current.getItems().addAll(compareTo.getItems());
    104                     for (int k = 0; k < currentCentroid.length; k++) {
    105                         currentCentroid[k] += (compareToCentroid[k] - currentCentroid[k]) / 2.0;
    106                     }
    107                     j--;
    108                 }
    109             }
    110         }
    111 
    112         // Convert data to final format, de-normalizing the hue.
    113         mQuantizedColors = new ArrayList<>();
    114         for (KMeans.Mean mean : optimalMeans) {
    115             if (mean.getItems().size() == 0) {
    116                 continue;
    117             }
    118             float[] centroid = mean.getCentroid();
    119             mQuantizedColors.add(new Palette.Swatch(new float[]{
    120                     centroid[0] * 360f,
    121                     centroid[1],
    122                     centroid[2]
    123             }, mean.getItems().size()));
    124         }
    125     }
    126 
    127     private List<KMeans.Mean> getOptimalKMeans(int k, float[][] inputData) {
    128         List<KMeans.Mean> optimal = null;
    129         double optimalScore = -Double.MAX_VALUE;
    130         int runs = mInitializations;
    131         while (runs > 0) {
    132             if (DEBUG) {
    133                 Log.d(TAG, "k-means run: " + runs);
    134             }
    135             List<KMeans.Mean> means = mKMeans.predict(k, inputData);
    136             double score = KMeans.score(means);
    137             if (optimal == null || score > optimalScore) {
    138                 if (DEBUG) {
    139                     Log.d(TAG, "\tnew optimal score: " + score);
    140                 }
    141                 optimalScore = score;
    142                 optimal = means;
    143             }
    144             runs--;
    145         }
    146 
    147         return optimal;
    148     }
    149 
    150     @Override
    151     public List<Palette.Swatch> getQuantizedColors() {
    152         return mQuantizedColors;
    153     }
    154 }
    155