Home | History | Annotate | Download | only in learning
      1 /*
      2  * Copyright (C) 2011 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 
     18 package android.bordeaux.learning;
     19 
     20 import android.util.Log;
     21 
     22 import java.io.Serializable;
     23 import java.util.List;
     24 import java.util.Arrays;
     25 import java.util.ArrayList;
     26 import java.util.HashMap;
     27 import java.util.Map;
     28 
     29 /**
     30  * Stochastic Linear Ranker, learns how to rank a sample. The learned rank score
     31  * can be used to compare samples.
     32  * This java class wraps the native StochasticLinearRanker class.
     33  * To update the ranker, call updateClassifier with two samples, with the first
     34  * one having higher rank than the second one.
     35  * To get the rank score of the sample call scoreSample.
     36  *  TODO: adding more interfaces for changing the learning parameters
     37  */
     38 public class StochasticLinearRanker {
     39     String TAG = "StochasticLinearRanker";
     40     public static int VAR_NUM = 14;
     41     static public class Model implements Serializable {
     42         public HashMap<String, Float> weights = new HashMap<String, Float>();
     43         public float weightNormalizer = 1;
     44         public HashMap<String, String> parameters = new HashMap<String, String>();
     45     }
     46 
     47     /**
     48      * Initializing a ranker
     49      */
     50     public StochasticLinearRanker() {
     51         mNativeClassifier = initNativeClassifier();
     52     }
     53 
     54     /**
     55      * Reset the ranker
     56      */
     57     public void resetRanker(){
     58         deleteNativeClassifier(mNativeClassifier);
     59         mNativeClassifier = initNativeClassifier();
     60     }
     61 
     62     /**
     63      * Train the ranker with a pair of samples. A sample,  a pair of arrays of
     64      * keys and values. The first sample should have higher rank than the second
     65      * one.
     66      */
     67     public boolean updateClassifier(String[] keys_positive,
     68                                     float[] values_positive,
     69                                     String[] keys_negative,
     70                                     float[] values_negative) {
     71         return nativeUpdateClassifier(keys_positive, values_positive,
     72                                       keys_negative, values_negative,
     73                                       mNativeClassifier);
     74     }
     75 
     76     /**
     77      * Get the rank score of the sample, a sample is a list of key, value pairs.
     78      */
     79     public float scoreSample(String[] keys, float[] values) {
     80         return nativeScoreSample(keys, values, mNativeClassifier);
     81     }
     82 
     83     /**
     84      * Get the current model and parameters of ranker
     85      */
     86     public Model getUModel(){
     87         Model slrModel = new Model();
     88         int len = nativeGetLengthClassifier(mNativeClassifier);
     89         String[] wKeys = new String[len];
     90         float[] wValues = new float[len];
     91         float wNormalizer = 1;
     92         nativeGetWeightClassifier(wKeys, wValues, wNormalizer, mNativeClassifier);
     93         slrModel.weightNormalizer = wNormalizer;
     94         for (int  i=0; i< wKeys.length ; i++)
     95             slrModel.weights.put(wKeys[i], wValues[i]);
     96 
     97         String[] paramKeys = new String[VAR_NUM];
     98         String[] paramValues = new String[VAR_NUM];
     99         nativeGetParameterClassifier(paramKeys, paramValues, mNativeClassifier);
    100         for (int  i=0; i< paramKeys.length ; i++)
    101             slrModel.parameters.put(paramKeys[i], paramValues[i]);
    102         return slrModel;
    103     }
    104 
    105     /**
    106      * load the given model and parameters to the ranker
    107      */
    108     public boolean loadModel(Model model) {
    109         String[] wKeys = new String[model.weights.size()];
    110         float[] wValues = new float[model.weights.size()];
    111         int i = 0 ;
    112         for (Map.Entry<String, Float> e : model.weights.entrySet()){
    113             wKeys[i] = e.getKey();
    114             wValues[i] = e.getValue();
    115             i++;
    116         }
    117         boolean res = setModelWeights(wKeys, wValues, model.weightNormalizer);
    118         if (!res)
    119             return false;
    120 
    121         for (Map.Entry<String, String> e : model.parameters.entrySet()){
    122             res = setModelParameter(e.getKey(), e.getValue());
    123             if (!res)
    124                 return false;
    125         }
    126         return res;
    127     }
    128 
    129     public boolean setModelWeights(String[] keys, float [] values, float normalizer){
    130         return nativeSetWeightClassifier(keys, values, normalizer, mNativeClassifier);
    131     }
    132 
    133     public boolean setModelParameter(String key, String value){
    134         boolean res = nativeSetParameterClassifier(key, value, mNativeClassifier);
    135         return res;
    136     }
    137 
    138     /**
    139      * Print a model for debugging
    140      */
    141     public void print(Model model){
    142         String Sw = "";
    143         String Sp = "";
    144         for (Map.Entry<String, Float> e : model.weights.entrySet())
    145             Sw = Sw + "<" + e.getKey() + "," + e.getValue() + "> ";
    146         for (Map.Entry<String, String> e : model.parameters.entrySet())
    147             Sp = Sp + "<" + e.getKey() + "," + e.getValue() + "> ";
    148         Log.i(TAG, "Weights are " + Sw);
    149         Log.i(TAG, "Normalizer is " + model.weightNormalizer);
    150         Log.i(TAG, "Parameters are " + Sp);
    151     }
    152 
    153     @Override
    154     protected void finalize() throws Throwable {
    155         deleteNativeClassifier(mNativeClassifier);
    156     }
    157 
    158     static {
    159         System.loadLibrary("bordeaux");
    160     }
    161 
    162     private int mNativeClassifier;
    163 
    164     /*
    165      * The following methods are the java stubs for the jni implementations.
    166      */
    167     private native int initNativeClassifier();
    168 
    169     private native void deleteNativeClassifier(int classifierPtr);
    170 
    171     private native boolean nativeUpdateClassifier(
    172             String[] keys_positive,
    173             float[] values_positive,
    174             String[] keys_negative,
    175             float[] values_negative,
    176             int classifierPtr);
    177 
    178     private native float nativeScoreSample(String[] keys, float[] values, int classifierPtr);
    179 
    180     private native void nativeGetWeightClassifier(String [] keys, float[] values, float normalizer,
    181                                                   int classifierPtr);
    182 
    183     private native void nativeGetParameterClassifier(String [] keys, String[] values,
    184                                                   int classifierPtr);
    185 
    186     private native int nativeGetLengthClassifier(int classifierPtr);
    187 
    188     private native boolean nativeSetWeightClassifier(String [] keys, float[] values,
    189                                                      float normalizer, int classifierPtr);
    190 
    191     private native boolean nativeSetParameterClassifier(String key, String value,
    192                                                         int classifierPtr);
    193 }
    194