Home | History | Annotate | Download | only in learning
      1 /*
      2  * Copyright (C) 2011 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 package android.bordeaux.learning;
     18 
     19 import android.util.Log;
     20 
     21 import java.io.ByteArrayInputStream;
     22 import java.io.ByteArrayOutputStream;
     23 import java.io.IOException;
     24 import java.io.ObjectInputStream;
     25 import java.io.ObjectOutputStream;
     26 import java.io.Serializable;
     27 import java.util.ArrayList;
     28 import java.util.Collections;
     29 import java.util.Comparator;
     30 import java.util.HashMap;
     31 import java.util.HashSet;
     32 import java.util.Iterator;
     33 import java.util.List;
     34 import java.util.Map;
     35 import java.util.Map.Entry;
     36 import java.util.concurrent.ConcurrentHashMap;
     37 /**
     38  * A histogram based predictor which records co-occurrences of applations with a speficic
     39  * feature, for example, location, * time of day, etc. The histogram is kept in a two level
     40  * hash table. The first level key is the feature value and the second level key is the app
     41  * id.
     42  */
     43 // TODOS:
     44 // 1. Use forgetting factor to downweight istances propotional to the time
     45 // 2. Different features could have different weights on prediction scores.
     46 // 3. Add function to remove sampleid (i.e. remove apps that are uninstalled).
     47 
     48 
     49 public class HistogramPredictor {
     50     final static String TAG = "HistogramPredictor";
     51 
     52     private HashMap<String, HistogramCounter> mPredictor =
     53             new HashMap<String, HistogramCounter>();
     54 
     55     private HashMap<String, Integer> mClassCounts = new HashMap<String, Integer>();
     56     private HashSet<String> mBlacklist = new HashSet<String>();
     57 
     58     private static final int MINIMAL_FEATURE_VALUE_COUNTS = 5;
     59     private static final int MINIMAL_APP_APPEARANCE_COUNTS = 5;
     60 
     61     // This parameter ranges from 0 to 1 which determines the effect of app prior.
     62     // When it is set to 0, app prior means completely neglected. When it is set to 1
     63     // the predictor is a standard naive bayes model.
     64     private static final int PRIOR_K_VALUE = 1;
     65 
     66     private static final String[] APP_BLACKLIST = {
     67         "com.android.contacts",
     68         "com.android.chrome",
     69         "com.android.providers.downloads.ui",
     70         "com.android.settings",
     71         "com.android.vending",
     72         "com.android.mms",
     73         "com.google.android.gm",
     74         "com.google.android.gallery3d",
     75         "com.google.android.apps.googlevoice",
     76     };
     77 
     78     public HistogramPredictor(String[] blackList) {
     79         for (String appName : blackList) {
     80             mBlacklist.add(appName);
     81         }
     82     }
     83 
     84     /*
     85      * This class keeps the histogram counts for each feature and provide the
     86      * joint probabilities of <feature, class>.
     87      */
     88     private class HistogramCounter {
     89         private HashMap<String, HashMap<String, Integer> > mCounter =
     90                 new HashMap<String, HashMap<String, Integer> >();
     91 
     92         public HistogramCounter() {
     93             mCounter.clear();
     94         }
     95 
     96         public void setCounter(HashMap<String, HashMap<String, Integer> > counter) {
     97             resetCounter();
     98             mCounter.putAll(counter);
     99         }
    100 
    101         public void resetCounter() {
    102             mCounter.clear();
    103         }
    104 
    105         public void addSample(String className, String featureValue) {
    106             HashMap<String, Integer> classCounts;
    107 
    108             if (!mCounter.containsKey(featureValue)) {
    109                 classCounts = new HashMap<String, Integer>();
    110                 mCounter.put(featureValue, classCounts);
    111             } else {
    112                 classCounts = mCounter.get(featureValue);
    113             }
    114             int count = (classCounts.containsKey(className)) ?
    115                     classCounts.get(className) + 1 : 1;
    116             classCounts.put(className, count);
    117         }
    118 
    119         public HashMap<String, Double> getClassScores(String featureValue) {
    120             HashMap<String, Double> classScores = new HashMap<String, Double>();
    121 
    122             if (mCounter.containsKey(featureValue)) {
    123                 int totalCount = 0;
    124                 for(Map.Entry<String, Integer> entry :
    125                         mCounter.get(featureValue).entrySet()) {
    126                     String app = entry.getKey();
    127                     int count = entry.getValue();
    128 
    129                     // For apps with counts less than or equal to one, we treated
    130                     // those as having count one. Hence their score, i.e. log(count)
    131                     // would be zero. classScroes stores only apps with non-zero scores.
    132                     // Note that totalCount also neglect app with single occurrence.
    133                     if (count > 1) {
    134                         double score = Math.log((double) count);
    135                         classScores.put(app, score);
    136                         totalCount += count;
    137                     }
    138                 }
    139                 if (totalCount < MINIMAL_FEATURE_VALUE_COUNTS) {
    140                     classScores.clear();
    141                 }
    142             }
    143             return classScores;
    144         }
    145 
    146         public byte[] getModel() {
    147             try {
    148                 ByteArrayOutputStream byteStream = new ByteArrayOutputStream();
    149                 ObjectOutputStream objStream = new ObjectOutputStream(byteStream);
    150                 synchronized(mCounter) {
    151                     objStream.writeObject(mCounter);
    152                 }
    153                 byte[] bytes = byteStream.toByteArray();
    154                 return bytes;
    155             } catch (IOException e) {
    156                 throw new RuntimeException("Can't get model");
    157             }
    158         }
    159 
    160         public boolean setModel(final byte[] modelData) {
    161             mCounter.clear();
    162             HashMap<String, HashMap<String, Integer> > model;
    163 
    164             try {
    165                 ByteArrayInputStream input = new ByteArrayInputStream(modelData);
    166                 ObjectInputStream objStream = new ObjectInputStream(input);
    167                 model = (HashMap<String, HashMap<String, Integer> >) objStream.readObject();
    168             } catch (IOException e) {
    169                 throw new RuntimeException("Can't load model");
    170             } catch (ClassNotFoundException e) {
    171                 throw new RuntimeException("Learning class not found");
    172             }
    173 
    174             synchronized(mCounter) {
    175                 mCounter.putAll(model);
    176             }
    177 
    178             return true;
    179         }
    180 
    181 
    182         public HashMap<String, HashMap<String, Integer> > getCounter() {
    183             return mCounter;
    184         }
    185 
    186         public String toString() {
    187             String result = "";
    188             for (Map.Entry<String, HashMap<String, Integer> > entry :
    189                      mCounter.entrySet()) {
    190                 result += "{ " + entry.getKey() + " : " +
    191                     entry.getValue().toString() + " }";
    192             }
    193             return result;
    194         }
    195     }
    196 
    197     /*
    198      * Given a map of feature name -value pairs returns topK mostly likely apps to
    199      * be launched with corresponding likelihoods. If topK is set zero, it will return
    200      * the whole list.
    201      */
    202     public List<Map.Entry<String, Double> > findTopClasses(Map<String, String> features, int topK) {
    203         // Most sophisticated function in this class
    204         HashMap<String, Double> appScores = new HashMap<String, Double>();
    205         int validFeatureCount = 0;
    206 
    207         // compute all app scores
    208         for (Map.Entry<String, HistogramCounter> entry : mPredictor.entrySet()) {
    209             String featureName = entry.getKey();
    210             HistogramCounter counter = entry.getValue();
    211 
    212             if (features.containsKey(featureName)) {
    213                 String featureValue = features.get(featureName);
    214                 HashMap<String, Double> scoreMap = counter.getClassScores(featureValue);
    215 
    216                 if (scoreMap.isEmpty()) {
    217                   continue;
    218                 }
    219                 validFeatureCount++;
    220 
    221                 for (Map.Entry<String, Double> item : scoreMap.entrySet()) {
    222                     String appName = item.getKey();
    223                     double appScore = item.getValue();
    224                     if (appScores.containsKey(appName)) {
    225                         appScore += appScores.get(appName);
    226                     }
    227                     appScores.put(appName, appScore);
    228                 }
    229             }
    230         }
    231 
    232         HashMap<String, Double> appCandidates = new HashMap<String, Double>();
    233         for (Map.Entry<String, Double> entry : appScores.entrySet()) {
    234             String appName = entry.getKey();
    235             if (mBlacklist.contains(appName)) {
    236                 Log.i(TAG, appName + " is in blacklist");
    237                 continue;
    238             }
    239             if (!mClassCounts.containsKey(appName)) {
    240                 throw new RuntimeException("class count error!");
    241             }
    242             int appCount = mClassCounts.get(appName);
    243             if (appCount < MINIMAL_APP_APPEARANCE_COUNTS) {
    244                 Log.i(TAG, appName + " doesn't have enough counts");
    245                 continue;
    246             }
    247 
    248             double appScore = entry.getValue();
    249             double appPrior = Math.log((double) appCount);
    250             appCandidates.put(appName,
    251                               appScore - appPrior * (validFeatureCount - PRIOR_K_VALUE));
    252         }
    253 
    254         // sort app scores
    255         List<Map.Entry<String, Double> > appList =
    256                new ArrayList<Map.Entry<String, Double> >(appCandidates.size());
    257         appList.addAll(appCandidates.entrySet());
    258         Collections.sort(appList, new  Comparator<Map.Entry<String, Double> >() {
    259             public int compare(Map.Entry<String, Double> o1,
    260                                Map.Entry<String, Double> o2) {
    261                 return o2.getValue().compareTo(o1.getValue());
    262             }
    263         });
    264 
    265         if (topK == 0) {
    266             topK = appList.size();
    267         }
    268         return appList.subList(0, Math.min(topK, appList.size()));
    269     }
    270 
    271     /*
    272      * Add a new observation of given sample id and features to the histograms
    273      */
    274     public void addSample(String sampleId, Map<String, String> features) {
    275         for (Map.Entry<String, String> entry : features.entrySet()) {
    276             String featureName = entry.getKey();
    277             String featureValue = entry.getValue();
    278 
    279             useFeature(featureName);
    280             HistogramCounter counter = mPredictor.get(featureName);
    281             counter.addSample(sampleId, featureValue);
    282         }
    283 
    284         int sampleCount = (mClassCounts.containsKey(sampleId)) ?
    285             mClassCounts.get(sampleId) + 1 : 1;
    286         mClassCounts.put(sampleId, sampleCount);
    287     }
    288 
    289     /*
    290      * reset predictor to a empty model
    291      */
    292     public void resetPredictor() {
    293         // TODO: not sure this step would reduce memory waste
    294         for (HistogramCounter counter : mPredictor.values()) {
    295             counter.resetCounter();
    296         }
    297         mPredictor.clear();
    298         mClassCounts.clear();
    299     }
    300 
    301     /*
    302      * convert the prediction model into a byte array
    303      */
    304     public byte[] getModel() {
    305         // TODO: convert model to a more memory efficient data structure.
    306         HashMap<String, HashMap<String, HashMap<String, Integer > > > model =
    307                 new HashMap<String, HashMap<String, HashMap<String, Integer > > >();
    308         for(Map.Entry<String, HistogramCounter> entry : mPredictor.entrySet()) {
    309             model.put(entry.getKey(), entry.getValue().getCounter());
    310         }
    311 
    312         try {
    313             ByteArrayOutputStream byteStream = new ByteArrayOutputStream();
    314             ObjectOutputStream objStream = new ObjectOutputStream(byteStream);
    315             objStream.writeObject(model);
    316             byte[] bytes = byteStream.toByteArray();
    317             return bytes;
    318         } catch (IOException e) {
    319             throw new RuntimeException("Can't get model");
    320         }
    321     }
    322 
    323     /*
    324      * set the prediction model from a model data in the format of byte array
    325      */
    326     public boolean setModel(final byte[] modelData) {
    327         HashMap<String, HashMap<String, HashMap<String, Integer > > > model;
    328 
    329         try {
    330             ByteArrayInputStream input = new ByteArrayInputStream(modelData);
    331             ObjectInputStream objStream = new ObjectInputStream(input);
    332             model = (HashMap<String, HashMap<String, HashMap<String, Integer > > >)
    333                     objStream.readObject();
    334         } catch (IOException e) {
    335             throw new RuntimeException("Can't load model");
    336         } catch (ClassNotFoundException e) {
    337             throw new RuntimeException("Learning class not found");
    338         }
    339 
    340         resetPredictor();
    341         for (Map.Entry<String, HashMap<String, HashMap<String, Integer> > > entry :
    342                 model.entrySet()) {
    343             useFeature(entry.getKey());
    344             mPredictor.get(entry.getKey()).setCounter(entry.getValue());
    345         }
    346 
    347         // TODO: this is a temporary fix for now
    348         loadClassCounter();
    349 
    350         return true;
    351     }
    352 
    353     private void loadClassCounter() {
    354         String TIME_OF_WEEK = "Time of Week";
    355 
    356         if (!mPredictor.containsKey(TIME_OF_WEEK)) {
    357             throw new RuntimeException("Precition model error: missing Time of Week!");
    358         }
    359 
    360         HashMap<String, HashMap<String, Integer> > counter =
    361             mPredictor.get(TIME_OF_WEEK).getCounter();
    362 
    363         mClassCounts.clear();
    364         for (HashMap<String, Integer> map : counter.values()) {
    365             for (Map.Entry<String, Integer> entry : map.entrySet()) {
    366                 int classCount = entry.getValue();
    367                 String className = entry.getKey();
    368                 // mTotalClassCount += classCount;
    369 
    370                 if (mClassCounts.containsKey(className)) {
    371                     classCount += mClassCounts.get(className);
    372                 }
    373                 mClassCounts.put(className, classCount);
    374             }
    375         }
    376         Log.i(TAG, "class counts: " + mClassCounts);
    377     }
    378 
    379     private void useFeature(String featureName) {
    380         if (!mPredictor.containsKey(featureName)) {
    381             mPredictor.put(featureName, new HistogramCounter());
    382         }
    383     }
    384 }
    385