1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 // Simple JNI wrapper for the SmartSelection library. 18 19 #include "textclassifier_jni.h" 20 21 #include <jni.h> 22 #include <vector> 23 24 #include "lang_id/lang-id.h" 25 #include "smartselect/text-classification-model.h" 26 27 using libtextclassifier::TextClassificationModel; 28 using libtextclassifier::ModelOptions; 29 using libtextclassifier::nlp_core::lang_id::LangId; 30 31 namespace { 32 33 bool JStringToUtf8String(JNIEnv* env, const jstring& jstr, 34 std::string* result) { 35 if (jstr == nullptr) { 36 *result = std::string(); 37 return false; 38 } 39 40 jclass string_class = env->FindClass("java/lang/String"); 41 jmethodID get_bytes_id = 42 env->GetMethodID(string_class, "getBytes", "(Ljava/lang/String;)[B"); 43 44 jstring encoding = env->NewStringUTF("UTF-8"); 45 jbyteArray array = reinterpret_cast<jbyteArray>( 46 env->CallObjectMethod(jstr, get_bytes_id, encoding)); 47 48 jbyte* const array_bytes = env->GetByteArrayElements(array, JNI_FALSE); 49 int length = env->GetArrayLength(array); 50 51 *result = std::string(reinterpret_cast<char*>(array_bytes), length); 52 53 // Release the array. 54 env->ReleaseByteArrayElements(array, array_bytes, JNI_ABORT); 55 env->DeleteLocalRef(array); 56 env->DeleteLocalRef(string_class); 57 env->DeleteLocalRef(encoding); 58 59 return true; 60 } 61 62 std::string ToStlString(JNIEnv* env, const jstring& str) { 63 std::string result; 64 JStringToUtf8String(env, str, &result); 65 return result; 66 } 67 68 jobjectArray ScoredStringsToJObjectArray( 69 JNIEnv* env, const std::string& result_class_name, 70 const std::vector<std::pair<std::string, float>>& classification_result) { 71 jclass result_class = env->FindClass(result_class_name.c_str()); 72 jmethodID result_class_constructor = 73 env->GetMethodID(result_class, "<init>", "(Ljava/lang/String;F)V"); 74 75 jobjectArray results = 76 env->NewObjectArray(classification_result.size(), result_class, nullptr); 77 78 for (int i = 0; i < classification_result.size(); i++) { 79 jstring row_string = 80 env->NewStringUTF(classification_result[i].first.c_str()); 81 jobject result = 82 env->NewObject(result_class, result_class_constructor, row_string, 83 static_cast<jfloat>(classification_result[i].second)); 84 env->SetObjectArrayElement(results, i, result); 85 env->DeleteLocalRef(result); 86 } 87 env->DeleteLocalRef(result_class); 88 return results; 89 } 90 91 } // namespace 92 93 namespace libtextclassifier { 94 95 using libtextclassifier::CodepointSpan; 96 97 namespace { 98 99 CodepointSpan ConvertIndicesBMPUTF8(const std::string& utf8_str, 100 CodepointSpan orig_indices, 101 bool from_utf8) { 102 const libtextclassifier::UnicodeText unicode_str = 103 libtextclassifier::UTF8ToUnicodeText(utf8_str, /*do_copy=*/false); 104 105 int unicode_index = 0; 106 int bmp_index = 0; 107 108 const int* source_index; 109 const int* target_index; 110 if (from_utf8) { 111 source_index = &unicode_index; 112 target_index = &bmp_index; 113 } else { 114 source_index = &bmp_index; 115 target_index = &unicode_index; 116 } 117 118 CodepointSpan result{-1, -1}; 119 std::function<void()> assign_indices_fn = [&result, &orig_indices, 120 &source_index, &target_index]() { 121 if (orig_indices.first == *source_index) { 122 result.first = *target_index; 123 } 124 125 if (orig_indices.second == *source_index) { 126 result.second = *target_index; 127 } 128 }; 129 130 for (auto it = unicode_str.begin(); it != unicode_str.end(); 131 ++it, ++unicode_index, ++bmp_index) { 132 assign_indices_fn(); 133 134 // There is 1 extra character in the input for each UTF8 character > 0xFFFF. 135 if (*it > 0xFFFF) { 136 ++bmp_index; 137 } 138 } 139 assign_indices_fn(); 140 141 return result; 142 } 143 144 } // namespace 145 146 CodepointSpan ConvertIndicesBMPToUTF8(const std::string& utf8_str, 147 CodepointSpan orig_indices) { 148 return ConvertIndicesBMPUTF8(utf8_str, orig_indices, /*from_utf8=*/false); 149 } 150 151 CodepointSpan ConvertIndicesUTF8ToBMP(const std::string& utf8_str, 152 CodepointSpan orig_indices) { 153 return ConvertIndicesBMPUTF8(utf8_str, orig_indices, /*from_utf8=*/true); 154 } 155 156 } // namespace libtextclassifier 157 158 using libtextclassifier::ConvertIndicesUTF8ToBMP; 159 using libtextclassifier::ConvertIndicesBMPToUTF8; 160 using libtextclassifier::CodepointSpan; 161 162 JNIEXPORT jlong JNICALL 163 Java_android_view_textclassifier_SmartSelection_nativeNew(JNIEnv* env, 164 jobject thiz, 165 jint fd) { 166 TextClassificationModel* model = new TextClassificationModel(fd); 167 return reinterpret_cast<jlong>(model); 168 } 169 170 JNIEXPORT jintArray JNICALL 171 Java_android_view_textclassifier_SmartSelection_nativeSuggest( 172 JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin, 173 jint selection_end) { 174 TextClassificationModel* model = 175 reinterpret_cast<TextClassificationModel*>(ptr); 176 177 const std::string context_utf8 = ToStlString(env, context); 178 CodepointSpan input_indices = 179 ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end}); 180 CodepointSpan selection = 181 model->SuggestSelection(context_utf8, input_indices); 182 selection = ConvertIndicesUTF8ToBMP(context_utf8, selection); 183 184 jintArray result = env->NewIntArray(2); 185 env->SetIntArrayRegion(result, 0, 1, &(std::get<0>(selection))); 186 env->SetIntArrayRegion(result, 1, 1, &(std::get<1>(selection))); 187 return result; 188 } 189 190 JNIEXPORT jobjectArray JNICALL 191 Java_android_view_textclassifier_SmartSelection_nativeClassifyText( 192 JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin, 193 jint selection_end, jint input_flags) { 194 TextClassificationModel* ff_model = 195 reinterpret_cast<TextClassificationModel*>(ptr); 196 const std::vector<std::pair<std::string, float>> classification_result = 197 ff_model->ClassifyText(ToStlString(env, context), 198 {selection_begin, selection_end}, input_flags); 199 200 return ScoredStringsToJObjectArray( 201 env, "android/view/textclassifier/SmartSelection$ClassificationResult", 202 classification_result); 203 } 204 205 JNIEXPORT void JNICALL 206 Java_android_view_textclassifier_SmartSelection_nativeClose(JNIEnv* env, 207 jobject thiz, 208 jlong ptr) { 209 TextClassificationModel* model = 210 reinterpret_cast<TextClassificationModel*>(ptr); 211 delete model; 212 } 213 214 JNIEXPORT jlong JNICALL Java_android_view_textclassifier_LangId_nativeNew( 215 JNIEnv* env, jobject thiz, jint fd) { 216 return reinterpret_cast<jlong>(new LangId(fd)); 217 } 218 219 JNIEXPORT jstring JNICALL 220 Java_android_view_textclassifier_SmartSelection_nativeGetLanguage(JNIEnv* env, 221 jobject clazz, 222 jint fd) { 223 ModelOptions model_options; 224 if (ReadSelectionModelOptions(fd, &model_options)) { 225 return env->NewStringUTF(model_options.language().c_str()); 226 } else { 227 return env->NewStringUTF("UNK"); 228 } 229 } 230 231 JNIEXPORT jint JNICALL 232 Java_android_view_textclassifier_SmartSelection_nativeGetVersion(JNIEnv* env, 233 jobject clazz, 234 jint fd) { 235 ModelOptions model_options; 236 if (ReadSelectionModelOptions(fd, &model_options)) { 237 return model_options.version(); 238 } else { 239 return -1; 240 } 241 } 242 243 JNIEXPORT jobjectArray JNICALL 244 Java_android_view_textclassifier_LangId_nativeFindLanguages(JNIEnv* env, 245 jobject thiz, 246 jlong ptr, 247 jstring text) { 248 LangId* lang_id = reinterpret_cast<LangId*>(ptr); 249 const std::vector<std::pair<std::string, float>> scored_languages = 250 lang_id->FindLanguages(ToStlString(env, text)); 251 252 return ScoredStringsToJObjectArray( 253 env, "android/view/textclassifier/LangId$ClassificationResult", 254 scored_languages); 255 } 256 257 JNIEXPORT void JNICALL Java_android_view_textclassifier_LangId_nativeClose( 258 JNIEnv* env, jobject thiz, jlong ptr) { 259 LangId* lang_id = reinterpret_cast<LangId*>(ptr); 260 delete lang_id; 261 } 262 263 JNIEXPORT int JNICALL Java_android_view_textclassifier_LangId_nativeGetVersion( 264 JNIEnv* env, jobject clazz, jint fd) { 265 std::unique_ptr<LangId> lang_id(new LangId(fd)); 266 return lang_id->version(); 267 } 268