1 /* 2 * Copyright (C) 2011 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 18 package android.bordeaux.learning; 19 20 import android.util.Log; 21 22 import java.io.Serializable; 23 import java.util.List; 24 import java.util.Arrays; 25 import java.util.ArrayList; 26 import java.util.HashMap; 27 import java.util.Map; 28 29 /** 30 * Stochastic Linear Ranker, learns how to rank a sample. The learned rank score 31 * can be used to compare samples. 32 * This java class wraps the native StochasticLinearRanker class. 33 * To update the ranker, call updateClassifier with two samples, with the first 34 * one having higher rank than the second one. 35 * To get the rank score of the sample call scoreSample. 
* TODO: add more interfaces for changing the learning parameters.
 *
 * <p>NOTE(review): the native classifier handle is kept in an {@code int}
 * field ({@code mNativeClassifier}). If the JNI side stores a raw pointer in
 * it, the value would be truncated on 64-bit runtimes. Changing this requires
 * a matching change to the native signatures, so it is left as-is here.
 */
public class StochasticLinearRanker {
    /** Log tag used by {@link #print(Model)}. */
    String TAG = "StochasticLinearRanker";

    /**
     * Number of parameter slots requested from the native classifier by
     * {@link #getUModel()}.
     */
    public static int VAR_NUM = 14;

    /**
     * Plain-data snapshot of a ranker: feature weights keyed by feature name,
     * the weight normalizer, and named learning parameters (values as strings).
     */
    static public class Model implements Serializable {
        public HashMap<String, Float> weights = new HashMap<String, Float>();
        public float weightNormalizer = 1;
        public HashMap<String, String> parameters = new HashMap<String, String>();
    }

    /**
     * Creates a ranker backed by a newly initialized native classifier.
     */
    public StochasticLinearRanker() {
        mNativeClassifier = initNativeClassifier();
    }

    /**
     * Discards the current native classifier and replaces it with a freshly
     * initialized one.
     */
    public void resetRanker() {
        deleteNativeClassifier(mNativeClassifier);
        mNativeClassifier = initNativeClassifier();
    }

    /**
     * Trains the ranker with one ordered pair of samples. Each sample is given
     * as parallel arrays of feature keys and feature values; the "positive"
     * sample should rank higher than the "negative" one.
     *
     * @param keys_positive feature keys of the higher-ranked sample
     * @param values_positive feature values of the higher-ranked sample
     * @param keys_negative feature keys of the lower-ranked sample
     * @param values_negative feature values of the lower-ranked sample
     * @return the result reported by the native update call
     */
    public boolean updateClassifier(String[] keys_positive,
                                    float[] values_positive,
                                    String[] keys_negative,
                                    float[] values_negative) {
        return nativeUpdateClassifier(keys_positive, values_positive,
                                      keys_negative, values_negative,
                                      mNativeClassifier);
    }

    /**
     * Returns the rank score of a sample given as parallel arrays of feature
     * keys and values. Scores are meant to be compared between samples.
     *
     * @param keys feature keys of the sample
     * @param values feature values of the sample
     * @return the score computed by the native classifier
     */
    public float scoreSample(String[] keys, float[] values) {
        return nativeScoreSample(keys, values, mNativeClassifier);
    }

    /**
     * Reads the current weights and parameters out of the native classifier.
     *
     * @return a new {@link Model} holding the weights and {@link #VAR_NUM}
     *     parameter entries as filled in by the native side
     */
    public Model getUModel() {
        Model slrModel = new Model();

        int len = nativeGetLengthClassifier(mNativeClassifier);
        String[] wKeys = new String[len];
        float[] wValues = new float[len];
        // NOTE(review): Java passes primitives by value, so the native call
        // cannot write the normalizer back through this argument and
        // weightNormalizer always ends up as 1 here. Fixing this needs a JNI
        // signature change (e.g. a float[1] out-parameter). TODO
        float wNormalizer = 1;
        // Presumably fills wKeys/wValues in place; verify against the JNI side.
        nativeGetWeightClassifier(wKeys, wValues, wNormalizer, mNativeClassifier);
        slrModel.weightNormalizer = wNormalizer;
        for (int i = 0; i < wKeys.length; i++) {
            slrModel.weights.put(wKeys[i], wValues[i]);
        }

        String[] paramKeys = new String[VAR_NUM];
        String[] paramValues = new String[VAR_NUM];
        nativeGetParameterClassifier(paramKeys, paramValues, mNativeClassifier);
        for (int i = 0; i < paramKeys.length; i++) {
            slrModel.parameters.put(paramKeys[i], paramValues[i]);
        }
        return slrModel;
    }

    /**
     * Loads the given model's weights and then each of its parameters into the
     * native classifier, stopping at the first step that reports failure.
     *
     * @param model the model to load; a {@code null} weight value causes a
     *     {@link NullPointerException} when it is unboxed
     * @return {@code true} if the weights and every parameter were accepted,
     *     {@code false} as soon as one step fails
     */
    public boolean loadModel(Model model) {
        int size = model.weights.size();
        String[] wKeys = new String[size];
        float[] wValues = new float[size];
        int i = 0;
        for (Map.Entry<String, Float> e : model.weights.entrySet()) {
            wKeys[i] = e.getKey();
            wValues[i] = e.getValue();
            i++;
        }
        if (!setModelWeights(wKeys, wValues, model.weightNormalizer)) {
            return false;
        }

        for (Map.Entry<String, String> e : model.parameters.entrySet()) {
            if (!setModelParameter(e.getKey(), e.getValue())) {
                return false;
            }
        }
        return true;
    }

    /**
     * Sets the classifier weights from parallel key/value arrays.
     *
     * @param keys feature keys
     * @param values weight for each key
     * @param normalizer weight normalizer passed to the native classifier
     * @return the result reported by the native call
     */
    public boolean setModelWeights(String[] keys, float[] values, float normalizer) {
        return nativeSetWeightClassifier(keys, values, normalizer, mNativeClassifier);
    }

    /**
     * Sets a single named learning parameter on the native classifier.
     *
     * @param key parameter name
     * @param value parameter value, as a string
     * @return the result reported by the native call
     */
    public boolean setModelParameter(String key, String value) {
        return nativeSetParameterClassifier(key, value, mNativeClassifier);
    }

    /**
     * Logs a model's weights, normalizer and parameters at INFO level, for
     * debugging. Each entry is rendered as {@code <key,value> }.
     *
     * @param model the model to print
     */
    public void print(Model model) {
        // StringBuilder avoids quadratic string concatenation for large models.
        StringBuilder weightsText = new StringBuilder();
        for (Map.Entry<String, Float> e : model.weights.entrySet()) {
            weightsText.append('<').append(e.getKey()).append(',')
                    .append(e.getValue()).append("> ");
        }
        StringBuilder paramsText = new StringBuilder();
        for (Map.Entry<String, String> e : model.parameters.entrySet()) {
            paramsText.append('<').append(e.getKey()).append(',')
                    .append(e.getValue()).append("> ");
        }
        Log.i(TAG, "Weights are " + weightsText);
        Log.i(TAG, "Normalizer is " + model.weightNormalizer);
        Log.i(TAG, "Parameters are " + paramsText);
    }

    /**
     * Releases the native classifier when this object is garbage collected,
     * always chaining to the superclass finalizer.
     */
    @Override
    protected void finalize() throws Throwable {
        try {
            deleteNativeClassifier(mNativeClassifier);
        } finally {
            super.finalize();
        }
    }

    static {
        System.loadLibrary("bordeaux");
    }

    /** Handle to the native classifier, as returned by initNativeClassifier(). */
    private int mNativeClassifier;

    /*
     * The following methods are the Java stubs for the JNI implementations.
     */
    private native int initNativeClassifier();

    private native void deleteNativeClassifier(int classifierPtr);

    private native boolean nativeUpdateClassifier(
            String[] keys_positive,
            float[] values_positive,
            String[] keys_negative,
            float[] values_negative,
            int classifierPtr);

    private native float nativeScoreSample(String[] keys, float[] values, int classifierPtr);

    private native void nativeGetWeightClassifier(String[] keys, float[] values, float normalizer,
            int classifierPtr);

    private native void nativeGetParameterClassifier(String[] keys, String[] values,
            int classifierPtr);

    private native int nativeGetLengthClassifier(int classifierPtr);

    private native boolean nativeSetWeightClassifier(String[] keys, float[] values,
            float normalizer, int classifierPtr);

    private native boolean nativeSetParameterClassifier(String key, String value,
            int classifierPtr);
}