Home | History | Annotate | Download | only in resolver
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 package android.ext.services.resolver;
     18 
     19 import android.content.Context;
     20 import android.content.Intent;
     21 import android.content.SharedPreferences;
     22 import android.os.Environment;
     23 import android.os.IBinder;
     24 import android.os.storage.StorageManager;
     25 import android.service.resolver.ResolverRankerService;
     26 import android.service.resolver.ResolverTarget;
     27 import android.util.ArrayMap;
     28 import android.util.Log;
     29 
     30 import java.io.File;
     31 import java.util.Collection;
     32 import java.util.List;
     33 import java.util.Map;
     34 
     35 /**
     36  * A Logistic Regression based {@link android.service.resolver.ResolverRankerService}, to be used
     37  * in {@link ResolverComparator}.
     38  */
     39 public final class LRResolverRankerService extends ResolverRankerService {
     40     private static final String TAG = "LRResolverRankerService";
     41 
     42     private static final boolean DEBUG = false;
     43 
     44     private static final String PARAM_SHARED_PREF_NAME = "resolver_ranker_params";
     45     private static final String BIAS_PREF_KEY = "bias";
     46     private static final String VERSION_PREF_KEY = "version";
     47 
     48     private static final String LAUNCH_SCORE = "launch";
     49     private static final String TIME_SPENT_SCORE = "timeSpent";
     50     private static final String RECENCY_SCORE = "recency";
     51     private static final String CHOOSER_SCORE = "chooser";
     52 
     53     // parameters for a pre-trained model, to initialize the app ranker. When updating the
     54     // pre-trained model, please update these params, as well as initModel().
     55     private static final int CURRENT_VERSION = 1;
     56     private static final float LEARNING_RATE = 0.0001f;
     57     private static final float REGULARIZER_PARAM = 0.0001f;
     58 
     59     private SharedPreferences mParamSharedPref;
     60     private ArrayMap<String, Float> mFeatureWeights;
     61     private float mBias;
     62 
     63     @Override
     64     public IBinder onBind(Intent intent) {
     65         initModel();
     66         return super.onBind(intent);
     67     }
     68 
     69     @Override
     70     public void onPredictSharingProbabilities(List<ResolverTarget> targets) {
     71         final int size = targets.size();
     72         for (int i = 0; i < size; ++i) {
     73             ResolverTarget target = targets.get(i);
     74             ArrayMap<String, Float> features = getFeatures(target);
     75             target.setSelectProbability(predict(features));
     76         }
     77     }
     78 
     79     @Override
     80     public void onTrainRankingModel(List<ResolverTarget> targets, int selectedPosition) {
     81         final int size = targets.size();
     82         if (selectedPosition < 0 || selectedPosition >= size) {
     83             if (DEBUG) {
     84                 Log.d(TAG, "Invalid Position of Selected App " + selectedPosition);
     85             }
     86             return;
     87         }
     88         final ArrayMap<String, Float> positive = getFeatures(targets.get(selectedPosition));
     89         final float positiveProbability = targets.get(selectedPosition).getSelectProbability();
     90         final int targetSize = targets.size();
     91         for (int i = 0; i < targetSize; ++i) {
     92             if (i == selectedPosition) {
     93                 continue;
     94             }
     95             final ArrayMap<String, Float> negative = getFeatures(targets.get(i));
     96             final float negativeProbability = targets.get(i).getSelectProbability();
     97             if (negativeProbability > positiveProbability) {
     98                 update(negative, negativeProbability, false);
     99                 update(positive, positiveProbability, true);
    100             }
    101         }
    102         commitUpdate();
    103     }
    104 
    105     private void initModel() {
    106         mParamSharedPref = getParamSharedPref();
    107         mFeatureWeights = new ArrayMap<>(4);
    108         if (mParamSharedPref == null ||
    109                 mParamSharedPref.getInt(VERSION_PREF_KEY, 0) < CURRENT_VERSION) {
    110             // Initializing the app ranker to a pre-trained model. When updating the pre-trained
    111             // model, please increment CURRENT_VERSION, and update LEARNING_RATE and
    112             // REGULARIZER_PARAM.
    113             mBias = -1.6568f;
    114             mFeatureWeights.put(LAUNCH_SCORE, 2.5543f);
    115             mFeatureWeights.put(TIME_SPENT_SCORE, 2.8412f);
    116             mFeatureWeights.put(RECENCY_SCORE, 0.269f);
    117             mFeatureWeights.put(CHOOSER_SCORE, 4.2222f);
    118         } else {
    119             mBias = mParamSharedPref.getFloat(BIAS_PREF_KEY, 0.0f);
    120             mFeatureWeights.put(LAUNCH_SCORE, mParamSharedPref.getFloat(LAUNCH_SCORE, 0.0f));
    121             mFeatureWeights.put(
    122                     TIME_SPENT_SCORE, mParamSharedPref.getFloat(TIME_SPENT_SCORE, 0.0f));
    123             mFeatureWeights.put(RECENCY_SCORE, mParamSharedPref.getFloat(RECENCY_SCORE, 0.0f));
    124             mFeatureWeights.put(CHOOSER_SCORE, mParamSharedPref.getFloat(CHOOSER_SCORE, 0.0f));
    125         }
    126     }
    127 
    128     private ArrayMap<String, Float> getFeatures(ResolverTarget target) {
    129         ArrayMap<String, Float> features = new ArrayMap<>(4);
    130         features.put(RECENCY_SCORE, target.getRecencyScore());
    131         features.put(TIME_SPENT_SCORE, target.getTimeSpentScore());
    132         features.put(LAUNCH_SCORE, target.getLaunchScore());
    133         features.put(CHOOSER_SCORE, target.getChooserScore());
    134         return features;
    135     }
    136 
    137     private float predict(ArrayMap<String, Float> target) {
    138         if (target == null) {
    139             return 0.0f;
    140         }
    141         final int featureSize = target.size();
    142         float sum = 0.0f;
    143         for (int i = 0; i < featureSize; i++) {
    144             String featureName = target.keyAt(i);
    145             float weight = mFeatureWeights.getOrDefault(featureName, 0.0f);
    146             sum += weight * target.valueAt(i);
    147         }
    148         return (float) (1.0 / (1.0 + Math.exp(-mBias - sum)));
    149     }
    150 
    151     private void update(ArrayMap<String, Float> target, float predict, boolean isSelected) {
    152         if (target == null) {
    153             return;
    154         }
    155         final int featureSize = target.size();
    156         float error = isSelected ? 1.0f - predict : -predict;
    157         for (int i = 0; i < featureSize; i++) {
    158             String featureName = target.keyAt(i);
    159             float currentWeight = mFeatureWeights.getOrDefault(featureName, 0.0f);
    160             mBias += LEARNING_RATE * error;
    161             currentWeight = currentWeight - LEARNING_RATE * REGULARIZER_PARAM * currentWeight +
    162                     LEARNING_RATE * error * target.valueAt(i);
    163             mFeatureWeights.put(featureName, currentWeight);
    164         }
    165         if (DEBUG) {
    166             Log.d(TAG, "Weights: " + mFeatureWeights + " Bias: " + mBias);
    167         }
    168     }
    169 
    170     private void commitUpdate() {
    171         try {
    172             SharedPreferences.Editor editor = mParamSharedPref.edit();
    173             editor.putFloat(BIAS_PREF_KEY, mBias);
    174             final int size = mFeatureWeights.size();
    175             for (int i = 0; i < size; i++) {
    176                 editor.putFloat(mFeatureWeights.keyAt(i), mFeatureWeights.valueAt(i));
    177             }
    178             editor.putInt(VERSION_PREF_KEY, CURRENT_VERSION);
    179             editor.apply();
    180         } catch (Exception e) {
    181             Log.e(TAG, "Failed to commit update" + e);
    182         }
    183     }
    184 
    185     private SharedPreferences getParamSharedPref() {
    186         // The package info in the context isn't initialized in the way it is for normal apps,
    187         // so the standard, name-based context.getSharedPreferences doesn't work. Instead, we
    188         // build the path manually below using the same policy that appears in ContextImpl.
    189         if (DEBUG) {
    190             Log.d(TAG, "Context Package Name: " + getPackageName());
    191         }
    192         final File prefsFile = new File(new File(
    193                 Environment.getDataUserCePackageDirectory(
    194                         StorageManager.UUID_PRIVATE_INTERNAL, getUserId(), getPackageName()),
    195                 "shared_prefs"),
    196                 PARAM_SHARED_PREF_NAME + ".xml");
    197         return getSharedPreferences(prefsFile, Context.MODE_PRIVATE);
    198     }
    199 }