Home | History | Annotate | Download | only in libtextclassifier
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
// Simple JNI wrapper for the SmartSelection and LangId libraries.
     18 
     19 #include "textclassifier_jni.h"
     20 
     21 #include <jni.h>
     22 #include <vector>
     23 
     24 #include "lang_id/lang-id.h"
     25 #include "smartselect/text-classification-model.h"
     26 
     27 using libtextclassifier::TextClassificationModel;
     28 using libtextclassifier::ModelOptions;
     29 using libtextclassifier::nlp_core::lang_id::LangId;
     30 
     31 namespace {
     32 
     33 bool JStringToUtf8String(JNIEnv* env, const jstring& jstr,
     34                          std::string* result) {
     35   if (jstr == nullptr) {
     36     *result = std::string();
     37     return false;
     38   }
     39 
     40   jclass string_class = env->FindClass("java/lang/String");
     41   jmethodID get_bytes_id =
     42       env->GetMethodID(string_class, "getBytes", "(Ljava/lang/String;)[B");
     43 
     44   jstring encoding = env->NewStringUTF("UTF-8");
     45   jbyteArray array = reinterpret_cast<jbyteArray>(
     46       env->CallObjectMethod(jstr, get_bytes_id, encoding));
     47 
     48   jbyte* const array_bytes = env->GetByteArrayElements(array, JNI_FALSE);
     49   int length = env->GetArrayLength(array);
     50 
     51   *result = std::string(reinterpret_cast<char*>(array_bytes), length);
     52 
     53   // Release the array.
     54   env->ReleaseByteArrayElements(array, array_bytes, JNI_ABORT);
     55   env->DeleteLocalRef(array);
     56   env->DeleteLocalRef(string_class);
     57   env->DeleteLocalRef(encoding);
     58 
     59   return true;
     60 }
     61 
     62 std::string ToStlString(JNIEnv* env, const jstring& str) {
     63   std::string result;
     64   JStringToUtf8String(env, str, &result);
     65   return result;
     66 }
     67 
     68 jobjectArray ScoredStringsToJObjectArray(
     69     JNIEnv* env, const std::string& result_class_name,
     70     const std::vector<std::pair<std::string, float>>& classification_result) {
     71   jclass result_class = env->FindClass(result_class_name.c_str());
     72   jmethodID result_class_constructor =
     73       env->GetMethodID(result_class, "<init>", "(Ljava/lang/String;F)V");
     74 
     75   jobjectArray results =
     76       env->NewObjectArray(classification_result.size(), result_class, nullptr);
     77 
     78   for (int i = 0; i < classification_result.size(); i++) {
     79     jstring row_string =
     80         env->NewStringUTF(classification_result[i].first.c_str());
     81     jobject result =
     82         env->NewObject(result_class, result_class_constructor, row_string,
     83                        static_cast<jfloat>(classification_result[i].second));
     84     env->SetObjectArrayElement(results, i, result);
     85     env->DeleteLocalRef(result);
     86   }
     87   env->DeleteLocalRef(result_class);
     88   return results;
     89 }
     90 
     91 }  // namespace
     92 
     93 namespace libtextclassifier {
     94 
     95 using libtextclassifier::CodepointSpan;
     96 
     97 namespace {
     98 
     99 CodepointSpan ConvertIndicesBMPUTF8(const std::string& utf8_str,
    100                                     CodepointSpan orig_indices,
    101                                     bool from_utf8) {
    102   const libtextclassifier::UnicodeText unicode_str =
    103       libtextclassifier::UTF8ToUnicodeText(utf8_str, /*do_copy=*/false);
    104 
    105   int unicode_index = 0;
    106   int bmp_index = 0;
    107 
    108   const int* source_index;
    109   const int* target_index;
    110   if (from_utf8) {
    111     source_index = &unicode_index;
    112     target_index = &bmp_index;
    113   } else {
    114     source_index = &bmp_index;
    115     target_index = &unicode_index;
    116   }
    117 
    118   CodepointSpan result{-1, -1};
    119   std::function<void()> assign_indices_fn = [&result, &orig_indices,
    120                                              &source_index, &target_index]() {
    121     if (orig_indices.first == *source_index) {
    122       result.first = *target_index;
    123     }
    124 
    125     if (orig_indices.second == *source_index) {
    126       result.second = *target_index;
    127     }
    128   };
    129 
    130   for (auto it = unicode_str.begin(); it != unicode_str.end();
    131        ++it, ++unicode_index, ++bmp_index) {
    132     assign_indices_fn();
    133 
    134     // There is 1 extra character in the input for each UTF8 character > 0xFFFF.
    135     if (*it > 0xFFFF) {
    136       ++bmp_index;
    137     }
    138   }
    139   assign_indices_fn();
    140 
    141   return result;
    142 }
    143 
    144 }  // namespace
    145 
    146 CodepointSpan ConvertIndicesBMPToUTF8(const std::string& utf8_str,
    147                                       CodepointSpan orig_indices) {
    148   return ConvertIndicesBMPUTF8(utf8_str, orig_indices, /*from_utf8=*/false);
    149 }
    150 
    151 CodepointSpan ConvertIndicesUTF8ToBMP(const std::string& utf8_str,
    152                                       CodepointSpan orig_indices) {
    153   return ConvertIndicesBMPUTF8(utf8_str, orig_indices, /*from_utf8=*/true);
    154 }
    155 
    156 }  // namespace libtextclassifier
    157 
    158 using libtextclassifier::ConvertIndicesUTF8ToBMP;
    159 using libtextclassifier::ConvertIndicesBMPToUTF8;
    160 using libtextclassifier::CodepointSpan;
    161 
    162 JNIEXPORT jlong JNICALL
    163 Java_android_view_textclassifier_SmartSelection_nativeNew(JNIEnv* env,
    164                                                           jobject thiz,
    165                                                           jint fd) {
    166   TextClassificationModel* model = new TextClassificationModel(fd);
    167   return reinterpret_cast<jlong>(model);
    168 }
    169 
    170 JNIEXPORT jintArray JNICALL
    171 Java_android_view_textclassifier_SmartSelection_nativeSuggest(
    172     JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
    173     jint selection_end) {
    174   TextClassificationModel* model =
    175       reinterpret_cast<TextClassificationModel*>(ptr);
    176 
    177   const std::string context_utf8 = ToStlString(env, context);
    178   CodepointSpan input_indices =
    179       ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
    180   CodepointSpan selection =
    181       model->SuggestSelection(context_utf8, input_indices);
    182   selection = ConvertIndicesUTF8ToBMP(context_utf8, selection);
    183 
    184   jintArray result = env->NewIntArray(2);
    185   env->SetIntArrayRegion(result, 0, 1, &(std::get<0>(selection)));
    186   env->SetIntArrayRegion(result, 1, 1, &(std::get<1>(selection)));
    187   return result;
    188 }
    189 
    190 JNIEXPORT jobjectArray JNICALL
    191 Java_android_view_textclassifier_SmartSelection_nativeClassifyText(
    192     JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
    193     jint selection_end, jint input_flags) {
    194   TextClassificationModel* ff_model =
    195       reinterpret_cast<TextClassificationModel*>(ptr);
    196   const std::vector<std::pair<std::string, float>> classification_result =
    197       ff_model->ClassifyText(ToStlString(env, context),
    198                              {selection_begin, selection_end}, input_flags);
    199 
    200   return ScoredStringsToJObjectArray(
    201       env, "android/view/textclassifier/SmartSelection$ClassificationResult",
    202       classification_result);
    203 }
    204 
    205 JNIEXPORT void JNICALL
    206 Java_android_view_textclassifier_SmartSelection_nativeClose(JNIEnv* env,
    207                                                             jobject thiz,
    208                                                             jlong ptr) {
    209   TextClassificationModel* model =
    210       reinterpret_cast<TextClassificationModel*>(ptr);
    211   delete model;
    212 }
    213 
    214 JNIEXPORT jlong JNICALL Java_android_view_textclassifier_LangId_nativeNew(
    215     JNIEnv* env, jobject thiz, jint fd) {
    216   return reinterpret_cast<jlong>(new LangId(fd));
    217 }
    218 
    219 JNIEXPORT jstring JNICALL
    220 Java_android_view_textclassifier_SmartSelection_nativeGetLanguage(JNIEnv* env,
    221                                                                   jobject clazz,
    222                                                                   jint fd) {
    223   ModelOptions model_options;
    224   if (ReadSelectionModelOptions(fd, &model_options)) {
    225     return env->NewStringUTF(model_options.language().c_str());
    226   } else {
    227     return env->NewStringUTF("UNK");
    228   }
    229 }
    230 
    231 JNIEXPORT jint JNICALL
    232 Java_android_view_textclassifier_SmartSelection_nativeGetVersion(JNIEnv* env,
    233                                                                  jobject clazz,
    234                                                                  jint fd) {
    235   ModelOptions model_options;
    236   if (ReadSelectionModelOptions(fd, &model_options)) {
    237     return model_options.version();
    238   } else {
    239     return -1;
    240   }
    241 }
    242 
    243 JNIEXPORT jobjectArray JNICALL
    244 Java_android_view_textclassifier_LangId_nativeFindLanguages(JNIEnv* env,
    245                                                             jobject thiz,
    246                                                             jlong ptr,
    247                                                             jstring text) {
    248   LangId* lang_id = reinterpret_cast<LangId*>(ptr);
    249   const std::vector<std::pair<std::string, float>> scored_languages =
    250       lang_id->FindLanguages(ToStlString(env, text));
    251 
    252   return ScoredStringsToJObjectArray(
    253       env, "android/view/textclassifier/LangId$ClassificationResult",
    254       scored_languages);
    255 }
    256 
    257 JNIEXPORT void JNICALL Java_android_view_textclassifier_LangId_nativeClose(
    258     JNIEnv* env, jobject thiz, jlong ptr) {
    259   LangId* lang_id = reinterpret_cast<LangId*>(ptr);
    260   delete lang_id;
    261 }
    262 
    263 JNIEXPORT int JNICALL Java_android_view_textclassifier_LangId_nativeGetVersion(
    264     JNIEnv* env, jobject clazz, jint fd) {
    265   std::unique_ptr<LangId> lang_id(new LangId(fd));
    266   return lang_id->version();
    267 }
    268