/*
 * Copyright (C) 2011 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package android.bordeaux.learning;

import android.util.Log;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.ConcurrentHashMap;
/**
 * A histogram based predictor which records co-occurrences of applications with a specific
 * feature, for example, location, time of day, etc. The histogram is kept in a two level
 * hash table. The first level key is the feature value and the second level key is the app
 * id.
 */
// TODOS:
// 1. Use forgetting factor to downweight instances proportional to the time
// 2. Different features could have different weights on prediction scores.
// 3. Add function to remove sampleid (i.e. remove apps that are uninstalled).
47 48 49 public class HistogramPredictor { 50 final static String TAG = "HistogramPredictor"; 51 52 private HashMap<String, HistogramCounter> mPredictor = 53 new HashMap<String, HistogramCounter>(); 54 55 private HashMap<String, Integer> mClassCounts = new HashMap<String, Integer>(); 56 private HashSet<String> mBlacklist = new HashSet<String>(); 57 58 private static final int MINIMAL_FEATURE_VALUE_COUNTS = 5; 59 private static final int MINIMAL_APP_APPEARANCE_COUNTS = 5; 60 61 // This parameter ranges from 0 to 1 which determines the effect of app prior. 62 // When it is set to 0, app prior means completely neglected. When it is set to 1 63 // the predictor is a standard naive bayes model. 64 private static final int PRIOR_K_VALUE = 1; 65 66 private static final String[] APP_BLACKLIST = { 67 "com.android.contacts", 68 "com.android.chrome", 69 "com.android.providers.downloads.ui", 70 "com.android.settings", 71 "com.android.vending", 72 "com.android.mms", 73 "com.google.android.gm", 74 "com.google.android.gallery3d", 75 "com.google.android.apps.googlevoice", 76 }; 77 78 public HistogramPredictor(String[] blackList) { 79 for (String appName : blackList) { 80 mBlacklist.add(appName); 81 } 82 } 83 84 /* 85 * This class keeps the histogram counts for each feature and provide the 86 * joint probabilities of <feature, class>. 
87 */ 88 private class HistogramCounter { 89 private HashMap<String, HashMap<String, Integer> > mCounter = 90 new HashMap<String, HashMap<String, Integer> >(); 91 92 public HistogramCounter() { 93 mCounter.clear(); 94 } 95 96 public void setCounter(HashMap<String, HashMap<String, Integer> > counter) { 97 resetCounter(); 98 mCounter.putAll(counter); 99 } 100 101 public void resetCounter() { 102 mCounter.clear(); 103 } 104 105 public void addSample(String className, String featureValue) { 106 HashMap<String, Integer> classCounts; 107 108 if (!mCounter.containsKey(featureValue)) { 109 classCounts = new HashMap<String, Integer>(); 110 mCounter.put(featureValue, classCounts); 111 } else { 112 classCounts = mCounter.get(featureValue); 113 } 114 int count = (classCounts.containsKey(className)) ? 115 classCounts.get(className) + 1 : 1; 116 classCounts.put(className, count); 117 } 118 119 public HashMap<String, Double> getClassScores(String featureValue) { 120 HashMap<String, Double> classScores = new HashMap<String, Double>(); 121 122 if (mCounter.containsKey(featureValue)) { 123 int totalCount = 0; 124 for(Map.Entry<String, Integer> entry : 125 mCounter.get(featureValue).entrySet()) { 126 String app = entry.getKey(); 127 int count = entry.getValue(); 128 129 // For apps with counts less than or equal to one, we treated 130 // those as having count one. Hence their score, i.e. log(count) 131 // would be zero. classScroes stores only apps with non-zero scores. 132 // Note that totalCount also neglect app with single occurrence. 
133 if (count > 1) { 134 double score = Math.log((double) count); 135 classScores.put(app, score); 136 totalCount += count; 137 } 138 } 139 if (totalCount < MINIMAL_FEATURE_VALUE_COUNTS) { 140 classScores.clear(); 141 } 142 } 143 return classScores; 144 } 145 146 public byte[] getModel() { 147 try { 148 ByteArrayOutputStream byteStream = new ByteArrayOutputStream(); 149 ObjectOutputStream objStream = new ObjectOutputStream(byteStream); 150 synchronized(mCounter) { 151 objStream.writeObject(mCounter); 152 } 153 byte[] bytes = byteStream.toByteArray(); 154 return bytes; 155 } catch (IOException e) { 156 throw new RuntimeException("Can't get model"); 157 } 158 } 159 160 public boolean setModel(final byte[] modelData) { 161 mCounter.clear(); 162 HashMap<String, HashMap<String, Integer> > model; 163 164 try { 165 ByteArrayInputStream input = new ByteArrayInputStream(modelData); 166 ObjectInputStream objStream = new ObjectInputStream(input); 167 model = (HashMap<String, HashMap<String, Integer> >) objStream.readObject(); 168 } catch (IOException e) { 169 throw new RuntimeException("Can't load model"); 170 } catch (ClassNotFoundException e) { 171 throw new RuntimeException("Learning class not found"); 172 } 173 174 synchronized(mCounter) { 175 mCounter.putAll(model); 176 } 177 178 return true; 179 } 180 181 182 public HashMap<String, HashMap<String, Integer> > getCounter() { 183 return mCounter; 184 } 185 186 public String toString() { 187 String result = ""; 188 for (Map.Entry<String, HashMap<String, Integer> > entry : 189 mCounter.entrySet()) { 190 result += "{ " + entry.getKey() + " : " + 191 entry.getValue().toString() + " }"; 192 } 193 return result; 194 } 195 } 196 197 /* 198 * Given a map of feature name -value pairs returns topK mostly likely apps to 199 * be launched with corresponding likelihoods. If topK is set zero, it will return 200 * the whole list. 
201 */ 202 public List<Map.Entry<String, Double> > findTopClasses(Map<String, String> features, int topK) { 203 // Most sophisticated function in this class 204 HashMap<String, Double> appScores = new HashMap<String, Double>(); 205 int validFeatureCount = 0; 206 207 // compute all app scores 208 for (Map.Entry<String, HistogramCounter> entry : mPredictor.entrySet()) { 209 String featureName = entry.getKey(); 210 HistogramCounter counter = entry.getValue(); 211 212 if (features.containsKey(featureName)) { 213 String featureValue = features.get(featureName); 214 HashMap<String, Double> scoreMap = counter.getClassScores(featureValue); 215 216 if (scoreMap.isEmpty()) { 217 continue; 218 } 219 validFeatureCount++; 220 221 for (Map.Entry<String, Double> item : scoreMap.entrySet()) { 222 String appName = item.getKey(); 223 double appScore = item.getValue(); 224 if (appScores.containsKey(appName)) { 225 appScore += appScores.get(appName); 226 } 227 appScores.put(appName, appScore); 228 } 229 } 230 } 231 232 HashMap<String, Double> appCandidates = new HashMap<String, Double>(); 233 for (Map.Entry<String, Double> entry : appScores.entrySet()) { 234 String appName = entry.getKey(); 235 if (mBlacklist.contains(appName)) { 236 Log.i(TAG, appName + " is in blacklist"); 237 continue; 238 } 239 if (!mClassCounts.containsKey(appName)) { 240 throw new RuntimeException("class count error!"); 241 } 242 int appCount = mClassCounts.get(appName); 243 if (appCount < MINIMAL_APP_APPEARANCE_COUNTS) { 244 Log.i(TAG, appName + " doesn't have enough counts"); 245 continue; 246 } 247 248 double appScore = entry.getValue(); 249 double appPrior = Math.log((double) appCount); 250 appCandidates.put(appName, 251 appScore - appPrior * (validFeatureCount - PRIOR_K_VALUE)); 252 } 253 254 // sort app scores 255 List<Map.Entry<String, Double> > appList = 256 new ArrayList<Map.Entry<String, Double> >(appCandidates.size()); 257 appList.addAll(appCandidates.entrySet()); 258 Collections.sort(appList, new 
Comparator<Map.Entry<String, Double> >() { 259 public int compare(Map.Entry<String, Double> o1, 260 Map.Entry<String, Double> o2) { 261 return o2.getValue().compareTo(o1.getValue()); 262 } 263 }); 264 265 if (topK == 0) { 266 topK = appList.size(); 267 } 268 return appList.subList(0, Math.min(topK, appList.size())); 269 } 270 271 /* 272 * Add a new observation of given sample id and features to the histograms 273 */ 274 public void addSample(String sampleId, Map<String, String> features) { 275 for (Map.Entry<String, String> entry : features.entrySet()) { 276 String featureName = entry.getKey(); 277 String featureValue = entry.getValue(); 278 279 useFeature(featureName); 280 HistogramCounter counter = mPredictor.get(featureName); 281 counter.addSample(sampleId, featureValue); 282 } 283 284 int sampleCount = (mClassCounts.containsKey(sampleId)) ? 285 mClassCounts.get(sampleId) + 1 : 1; 286 mClassCounts.put(sampleId, sampleCount); 287 } 288 289 /* 290 * reset predictor to a empty model 291 */ 292 public void resetPredictor() { 293 // TODO: not sure this step would reduce memory waste 294 for (HistogramCounter counter : mPredictor.values()) { 295 counter.resetCounter(); 296 } 297 mPredictor.clear(); 298 mClassCounts.clear(); 299 } 300 301 /* 302 * convert the prediction model into a byte array 303 */ 304 public byte[] getModel() { 305 // TODO: convert model to a more memory efficient data structure. 
306 HashMap<String, HashMap<String, HashMap<String, Integer > > > model = 307 new HashMap<String, HashMap<String, HashMap<String, Integer > > >(); 308 for(Map.Entry<String, HistogramCounter> entry : mPredictor.entrySet()) { 309 model.put(entry.getKey(), entry.getValue().getCounter()); 310 } 311 312 try { 313 ByteArrayOutputStream byteStream = new ByteArrayOutputStream(); 314 ObjectOutputStream objStream = new ObjectOutputStream(byteStream); 315 objStream.writeObject(model); 316 byte[] bytes = byteStream.toByteArray(); 317 return bytes; 318 } catch (IOException e) { 319 throw new RuntimeException("Can't get model"); 320 } 321 } 322 323 /* 324 * set the prediction model from a model data in the format of byte array 325 */ 326 public boolean setModel(final byte[] modelData) { 327 HashMap<String, HashMap<String, HashMap<String, Integer > > > model; 328 329 try { 330 ByteArrayInputStream input = new ByteArrayInputStream(modelData); 331 ObjectInputStream objStream = new ObjectInputStream(input); 332 model = (HashMap<String, HashMap<String, HashMap<String, Integer > > >) 333 objStream.readObject(); 334 } catch (IOException e) { 335 throw new RuntimeException("Can't load model"); 336 } catch (ClassNotFoundException e) { 337 throw new RuntimeException("Learning class not found"); 338 } 339 340 resetPredictor(); 341 for (Map.Entry<String, HashMap<String, HashMap<String, Integer> > > entry : 342 model.entrySet()) { 343 useFeature(entry.getKey()); 344 mPredictor.get(entry.getKey()).setCounter(entry.getValue()); 345 } 346 347 // TODO: this is a temporary fix for now 348 loadClassCounter(); 349 350 return true; 351 } 352 353 private void loadClassCounter() { 354 String TIME_OF_WEEK = "Time of Week"; 355 356 if (!mPredictor.containsKey(TIME_OF_WEEK)) { 357 throw new RuntimeException("Precition model error: missing Time of Week!"); 358 } 359 360 HashMap<String, HashMap<String, Integer> > counter = 361 mPredictor.get(TIME_OF_WEEK).getCounter(); 362 363 mClassCounts.clear(); 364 for 
(HashMap<String, Integer> map : counter.values()) { 365 for (Map.Entry<String, Integer> entry : map.entrySet()) { 366 int classCount = entry.getValue(); 367 String className = entry.getKey(); 368 // mTotalClassCount += classCount; 369 370 if (mClassCounts.containsKey(className)) { 371 classCount += mClassCounts.get(className); 372 } 373 mClassCounts.put(className, classCount); 374 } 375 } 376 Log.i(TAG, "class counts: " + mClassCounts); 377 } 378 379 private void useFeature(String featureName) { 380 if (!mPredictor.containsKey(featureName)) { 381 mPredictor.put(featureName, new HistogramCounter()); 382 } 383 } 384 } 385