Home | History | Annotate | Download | only in jni
      1 /*
      2  * Copyright (C) 2011 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "jni/jni_multiclass_pa.h"
     18 #include "native/multiclass_pa.h"
     19 
     20 #include <vector>
     21 
     22 using learningfw::MulticlassPA;
     23 using std::vector;
     24 using std::pair;
     25 
     26 void CreateIndexValuePairs(const int* indices, const float* values,
     27                            const int length, vector<pair<int, float> >* pairs) {
     28   pairs->clear();
     29 
     30   for (int i = 0; i < length; ++i) {
     31     pair<int, float> new_pair(indices[i], values[i]);
     32     pairs->push_back(new_pair);
     33   }
     34 }
     35 
     36 jint Java_android_bordeaux_learning_MulticlassPA_initNativeClassifier(JNIEnv* env,
     37                                                        jobject thiz,
     38                                                        jint num_classes,
     39                                                        jint num_dims,
     40                                                        jfloat aggressiveness) {
     41   MulticlassPA* classifier = new MulticlassPA(num_classes,
     42                                               num_dims,
     43                                               aggressiveness);
     44   return ((jint) classifier);
     45 }
     46 
     47 
     48 jboolean Java_android_bordeaux_learning_MulticlassPA_deleteNativeClassifier(JNIEnv* env,
     49                                                              jobject thiz,
     50                                                              jint paPtr) {
     51   MulticlassPA* classifier = (MulticlassPA*) paPtr;
     52   delete classifier;
     53   return JNI_TRUE;
     54 }
     55 
     56 jboolean Java_android_bordeaux_learning_MulticlassPA_nativeSparseTrainOneExample(JNIEnv* env,
     57                                                                   jobject thiz,
     58                                                                   jintArray index_array,
     59                                                                   jfloatArray value_array,
     60                                                                   jint target,
     61                                                                   jint paPtr) {
     62   MulticlassPA* classifier = (MulticlassPA*) paPtr;
     63 
     64   if (classifier && index_array && value_array) {
     65 
     66     jfloat* values = env->GetFloatArrayElements(value_array, NULL);
     67     jint* indices = env->GetIntArrayElements(index_array, NULL);
     68     const int value_len = env->GetArrayLength(value_array);
     69     const int index_len = env->GetArrayLength(index_array);
     70 
     71     if (values && indices && value_len == index_len) {
     72       vector<pair<int, float> > inputs;
     73 
     74       CreateIndexValuePairs(indices, values, value_len, &inputs);
     75       classifier->SparseTrainOneExample(inputs, target);
     76       env->ReleaseIntArrayElements(index_array, indices, JNI_ABORT);
     77       env->ReleaseFloatArrayElements(value_array, values, JNI_ABORT);
     78 
     79       return JNI_TRUE;
     80     }
     81     env->ReleaseIntArrayElements(index_array, indices, JNI_ABORT);
     82     env->ReleaseFloatArrayElements(value_array, values, JNI_ABORT);
     83   }
     84 
     85   return JNI_FALSE;
     86 }
     87 
     88 
     89 jint Java_android_bordeaux_learning_MulticlassPA_nativeSparseGetClass(JNIEnv* env,
     90                                                        jobject thiz,
     91                                                        jintArray index_array,
     92                                                        jfloatArray value_array,
     93                                                        jint paPtr) {
     94 
     95   MulticlassPA* classifier = (MulticlassPA*) paPtr;
     96 
     97   if (classifier && index_array && value_array) {
     98 
     99     jfloat* values = env->GetFloatArrayElements(value_array, NULL);
    100     jint* indices = env->GetIntArrayElements(index_array, NULL);
    101     const int value_len = env->GetArrayLength(value_array);
    102     const int index_len = env->GetArrayLength(index_array);
    103 
    104     if (values && indices && value_len == index_len) {
    105       vector<pair<int, float> > inputs;
    106       CreateIndexValuePairs(indices, values, value_len, &inputs);
    107       env->ReleaseIntArrayElements(index_array, indices, JNI_ABORT);
    108       env->ReleaseFloatArrayElements(value_array, values, JNI_ABORT);
    109       return classifier->SparseGetClass(inputs);
    110     }
    111     env->ReleaseIntArrayElements(index_array, indices, JNI_ABORT);
    112     env->ReleaseFloatArrayElements(value_array, values, JNI_ABORT);
    113   }
    114 
    115   return -1;
    116 }
    117