1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 // JNI wrapper for the TextClassifier. 18 19 #include "textclassifier_jni.h" 20 21 #include <jni.h> 22 #include <type_traits> 23 #include <vector> 24 25 #include "text-classifier.h" 26 #include "util/base/integral_types.h" 27 #include "util/java/scoped_local_ref.h" 28 #include "util/java/string_utils.h" 29 #include "util/memory/mmap.h" 30 #include "util/utf8/unilib.h" 31 32 using libtextclassifier2::AnnotatedSpan; 33 using libtextclassifier2::AnnotationOptions; 34 using libtextclassifier2::ClassificationOptions; 35 using libtextclassifier2::ClassificationResult; 36 using libtextclassifier2::CodepointSpan; 37 using libtextclassifier2::JStringToUtf8String; 38 using libtextclassifier2::Model; 39 using libtextclassifier2::ScopedLocalRef; 40 using libtextclassifier2::SelectionOptions; 41 using libtextclassifier2::TextClassifier; 42 #ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU 43 using libtextclassifier2::UniLib; 44 #endif 45 46 namespace libtextclassifier2 { 47 48 using libtextclassifier2::CodepointSpan; 49 50 namespace { 51 52 std::string ToStlString(JNIEnv* env, const jstring& str) { 53 std::string result; 54 JStringToUtf8String(env, str, &result); 55 return result; 56 } 57 58 jobjectArray ClassificationResultsToJObjectArray( 59 JNIEnv* env, 60 const std::vector<ClassificationResult>& classification_result) { 61 const ScopedLocalRef<jclass> result_class( 62 env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationResult"), 63 env); 64 if (!result_class) { 65 TC_LOG(ERROR) << "Couldn't find ClassificationResult class."; 66 return nullptr; 67 } 68 const ScopedLocalRef<jclass> datetime_parse_class( 69 env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$DatetimeResult"), env); 70 if (!datetime_parse_class) { 71 TC_LOG(ERROR) << "Couldn't find DatetimeResult class."; 72 return nullptr; 73 } 74 75 const jmethodID result_class_constructor = 76 env->GetMethodID(result_class.get(), "<init>", 77 "(Ljava/lang/String;FL" TC_PACKAGE_PATH TC_CLASS_NAME_STR 78 "$DatetimeResult;)V"); 79 const jmethodID datetime_parse_class_constructor = 80 env->GetMethodID(datetime_parse_class.get(), "<init>", "(JI)V"); 81 82 const jobjectArray results = env->NewObjectArray(classification_result.size(), 83 result_class.get(), nullptr); 84 for (int i = 0; i < classification_result.size(); i++) { 85 jstring row_string = 86 env->NewStringUTF(classification_result[i].collection.c_str()); 87 jobject row_datetime_parse = nullptr; 88 if (classification_result[i].datetime_parse_result.IsSet()) { 89 row_datetime_parse = env->NewObject( 90 datetime_parse_class.get(), datetime_parse_class_constructor, 91 classification_result[i].datetime_parse_result.time_ms_utc, 92 classification_result[i].datetime_parse_result.granularity); 93 } 94 jobject result = 95 env->NewObject(result_class.get(), result_class_constructor, row_string, 96 static_cast<jfloat>(classification_result[i].score), 97 row_datetime_parse); 98 env->SetObjectArrayElement(results, i, result); 99 env->DeleteLocalRef(result); 100 } 101 return results; 102 } 103 104 template <typename T, typename F> 105 std::pair<bool, T> CallJniMethod0(JNIEnv* env, jobject object, 106 jclass class_object, F function, 107 const std::string& method_name, 108 const std::string& return_java_type) { 109 const jmethodID method = env->GetMethodID(class_object, method_name.c_str(), 110 ("()" + return_java_type).c_str()); 111 if (!method) { 112 return std::make_pair(false, T()); 113 } 114 return std::make_pair(true, (env->*function)(object, method)); 115 } 116 117 SelectionOptions FromJavaSelectionOptions(JNIEnv* env, jobject joptions) { 118 if (!joptions) { 119 return {}; 120 } 121 122 const ScopedLocalRef<jclass> options_class( 123 env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$SelectionOptions"), 124 env); 125 const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>( 126 env, joptions, options_class.get(), &JNIEnv::CallObjectMethod, 127 "getLocales", "Ljava/lang/String;"); 128 if (!status_or_locales.first) { 129 return {}; 130 } 131 132 SelectionOptions options; 133 options.locales = 134 ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second)); 135 136 return options; 137 } 138 139 template <typename T> 140 T FromJavaOptionsInternal(JNIEnv* env, jobject joptions, 141 const std::string& class_name) { 142 if (!joptions) { 143 return {}; 144 } 145 146 const ScopedLocalRef<jclass> options_class(env->FindClass(class_name.c_str()), 147 env); 148 if (!options_class) { 149 return {}; 150 } 151 152 const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>( 153 env, joptions, options_class.get(), &JNIEnv::CallObjectMethod, 154 "getLocale", "Ljava/lang/String;"); 155 const std::pair<bool, jobject> status_or_reference_timezone = 156 CallJniMethod0<jobject>(env, joptions, options_class.get(), 157 &JNIEnv::CallObjectMethod, "getReferenceTimezone", 158 "Ljava/lang/String;"); 159 const std::pair<bool, int64> status_or_reference_time_ms_utc = 160 CallJniMethod0<int64>(env, joptions, options_class.get(), 161 &JNIEnv::CallLongMethod, "getReferenceTimeMsUtc", 162 "J"); 163 164 if (!status_or_locales.first || !status_or_reference_timezone.first || 165 !status_or_reference_time_ms_utc.first) { 166 return {}; 167 } 168 169 T options; 170 options.locales = 171 ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second)); 172 options.reference_timezone = ToStlString( 173 env, reinterpret_cast<jstring>(status_or_reference_timezone.second)); 174 options.reference_time_ms_utc = status_or_reference_time_ms_utc.second; 175 return options; 176 } 177 178 ClassificationOptions FromJavaClassificationOptions(JNIEnv* env, 179 jobject joptions) { 180 return FromJavaOptionsInternal<ClassificationOptions>( 181 env, joptions, 182 TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationOptions"); 183 } 184 185 AnnotationOptions FromJavaAnnotationOptions(JNIEnv* env, jobject joptions) { 186 return FromJavaOptionsInternal<AnnotationOptions>( 187 env, joptions, TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotationOptions"); 188 } 189 190 CodepointSpan ConvertIndicesBMPUTF8(const std::string& utf8_str, 191 CodepointSpan orig_indices, 192 bool from_utf8) { 193 const libtextclassifier2::UnicodeText unicode_str = 194 libtextclassifier2::UTF8ToUnicodeText(utf8_str, /*do_copy=*/false); 195 196 int unicode_index = 0; 197 int bmp_index = 0; 198 199 const int* source_index; 200 const int* target_index; 201 if (from_utf8) { 202 source_index = &unicode_index; 203 target_index = &bmp_index; 204 } else { 205 source_index = &bmp_index; 206 target_index = &unicode_index; 207 } 208 209 CodepointSpan result{-1, -1}; 210 std::function<void()> assign_indices_fn = [&result, &orig_indices, 211 &source_index, &target_index]() { 212 if (orig_indices.first == *source_index) { 213 result.first = *target_index; 214 } 215 216 if (orig_indices.second == *source_index) { 217 result.second = *target_index; 218 } 219 }; 220 221 for (auto it = unicode_str.begin(); it != unicode_str.end(); 222 ++it, ++unicode_index, ++bmp_index) { 223 assign_indices_fn(); 224 225 // There is 1 extra character in the input for each UTF8 character > 0xFFFF. 226 if (*it > 0xFFFF) { 227 ++bmp_index; 228 } 229 } 230 assign_indices_fn(); 231 232 return result; 233 } 234 235 } // namespace 236 237 CodepointSpan ConvertIndicesBMPToUTF8(const std::string& utf8_str, 238 CodepointSpan bmp_indices) { 239 return ConvertIndicesBMPUTF8(utf8_str, bmp_indices, /*from_utf8=*/false); 240 } 241 242 CodepointSpan ConvertIndicesUTF8ToBMP(const std::string& utf8_str, 243 CodepointSpan utf8_indices) { 244 return ConvertIndicesBMPUTF8(utf8_str, utf8_indices, /*from_utf8=*/true); 245 } 246 247 jint GetFdFromAssetFileDescriptor(JNIEnv* env, jobject afd) { 248 // Get system-level file descriptor from AssetFileDescriptor. 249 ScopedLocalRef<jclass> afd_class( 250 env->FindClass("android/content/res/AssetFileDescriptor"), env); 251 if (afd_class == nullptr) { 252 TC_LOG(ERROR) << "Couldn't find AssetFileDescriptor."; 253 return reinterpret_cast<jlong>(nullptr); 254 } 255 jmethodID afd_class_getFileDescriptor = env->GetMethodID( 256 afd_class.get(), "getFileDescriptor", "()Ljava/io/FileDescriptor;"); 257 if (afd_class_getFileDescriptor == nullptr) { 258 TC_LOG(ERROR) << "Couldn't find getFileDescriptor."; 259 return reinterpret_cast<jlong>(nullptr); 260 } 261 262 ScopedLocalRef<jclass> fd_class(env->FindClass("java/io/FileDescriptor"), 263 env); 264 if (fd_class == nullptr) { 265 TC_LOG(ERROR) << "Couldn't find FileDescriptor."; 266 return reinterpret_cast<jlong>(nullptr); 267 } 268 jfieldID fd_class_descriptor = 269 env->GetFieldID(fd_class.get(), "descriptor", "I"); 270 if (fd_class_descriptor == nullptr) { 271 TC_LOG(ERROR) << "Couldn't find descriptor."; 272 return reinterpret_cast<jlong>(nullptr); 273 } 274 275 jobject bundle_jfd = env->CallObjectMethod(afd, afd_class_getFileDescriptor); 276 return env->GetIntField(bundle_jfd, fd_class_descriptor); 277 } 278 279 jstring GetLocalesFromMmap(JNIEnv* env, libtextclassifier2::ScopedMmap* mmap) { 280 if (!mmap->handle().ok()) { 281 return env->NewStringUTF(""); 282 } 283 const Model* model = libtextclassifier2::ViewModel( 284 mmap->handle().start(), mmap->handle().num_bytes()); 285 if (!model || !model->locales()) { 286 return env->NewStringUTF(""); 287 } 288 return env->NewStringUTF(model->locales()->c_str()); 289 } 290 291 jint GetVersionFromMmap(JNIEnv* env, libtextclassifier2::ScopedMmap* mmap) { 292 if (!mmap->handle().ok()) { 293 return 0; 294 } 295 const Model* model = libtextclassifier2::ViewModel( 296 mmap->handle().start(), mmap->handle().num_bytes()); 297 if (!model) { 298 return 0; 299 } 300 return model->version(); 301 } 302 303 jstring GetNameFromMmap(JNIEnv* env, libtextclassifier2::ScopedMmap* mmap) { 304 if (!mmap->handle().ok()) { 305 return env->NewStringUTF(""); 306 } 307 const Model* model = libtextclassifier2::ViewModel( 308 mmap->handle().start(), mmap->handle().num_bytes()); 309 if (!model || !model->name()) { 310 return env->NewStringUTF(""); 311 } 312 return env->NewStringUTF(model->name()->c_str()); 313 } 314 315 } // namespace libtextclassifier2 316 317 using libtextclassifier2::ClassificationResultsToJObjectArray; 318 using libtextclassifier2::ConvertIndicesBMPToUTF8; 319 using libtextclassifier2::ConvertIndicesUTF8ToBMP; 320 using libtextclassifier2::FromJavaAnnotationOptions; 321 using libtextclassifier2::FromJavaClassificationOptions; 322 using libtextclassifier2::FromJavaSelectionOptions; 323 using libtextclassifier2::ToStlString; 324 325 JNI_METHOD(jlong, TC_CLASS_NAME, nativeNew) 326 (JNIEnv* env, jobject thiz, jint fd) { 327 #ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU 328 return reinterpret_cast<jlong>( 329 TextClassifier::FromFileDescriptor(fd).release(), new UniLib(env)); 330 #else 331 return reinterpret_cast<jlong>( 332 TextClassifier::FromFileDescriptor(fd).release()); 333 #endif 334 } 335 336 JNI_METHOD(jlong, TC_CLASS_NAME, nativeNewFromPath) 337 (JNIEnv* env, jobject thiz, jstring path) { 338 const std::string path_str = ToStlString(env, path); 339 #ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU 340 return reinterpret_cast<jlong>( 341 TextClassifier::FromPath(path_str, new UniLib(env)).release()); 342 #else 343 return reinterpret_cast<jlong>(TextClassifier::FromPath(path_str).release()); 344 #endif 345 } 346 347 JNI_METHOD(jlong, TC_CLASS_NAME, nativeNewFromAssetFileDescriptor) 348 (JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) { 349 const jint fd = libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd); 350 #ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU 351 return reinterpret_cast<jlong>( 352 TextClassifier::FromFileDescriptor(fd, offset, size, new UniLib(env)) 353 .release()); 354 #else 355 return reinterpret_cast<jlong>( 356 TextClassifier::FromFileDescriptor(fd, offset, size).release()); 357 #endif 358 } 359 360 JNI_METHOD(jintArray, TC_CLASS_NAME, nativeSuggestSelection) 361 (JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin, 362 jint selection_end, jobject options) { 363 if (!ptr) { 364 return nullptr; 365 } 366 367 TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr); 368 369 const std::string context_utf8 = ToStlString(env, context); 370 CodepointSpan input_indices = 371 ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end}); 372 CodepointSpan selection = model->SuggestSelection( 373 context_utf8, input_indices, FromJavaSelectionOptions(env, options)); 374 selection = ConvertIndicesUTF8ToBMP(context_utf8, selection); 375 376 jintArray result = env->NewIntArray(2); 377 env->SetIntArrayRegion(result, 0, 1, &(std::get<0>(selection))); 378 env->SetIntArrayRegion(result, 1, 1, &(std::get<1>(selection))); 379 return result; 380 } 381 382 JNI_METHOD(jobjectArray, TC_CLASS_NAME, nativeClassifyText) 383 (JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin, 384 jint selection_end, jobject options) { 385 if (!ptr) { 386 return nullptr; 387 } 388 TextClassifier* ff_model = reinterpret_cast<TextClassifier*>(ptr); 389 390 const std::string context_utf8 = ToStlString(env, context); 391 const CodepointSpan input_indices = 392 ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end}); 393 const std::vector<ClassificationResult> classification_result = 394 ff_model->ClassifyText(context_utf8, input_indices, 395 FromJavaClassificationOptions(env, options)); 396 397 return ClassificationResultsToJObjectArray(env, classification_result); 398 } 399 400 JNI_METHOD(jobjectArray, TC_CLASS_NAME, nativeAnnotate) 401 (JNIEnv* env, jobject thiz, jlong ptr, jstring context, jobject options) { 402 if (!ptr) { 403 return nullptr; 404 } 405 TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr); 406 std::string context_utf8 = ToStlString(env, context); 407 std::vector<AnnotatedSpan> annotations = 408 model->Annotate(context_utf8, FromJavaAnnotationOptions(env, options)); 409 410 jclass result_class = 411 env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotatedSpan"); 412 if (!result_class) { 413 TC_LOG(ERROR) << "Couldn't find result class: " 414 << TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotatedSpan"; 415 return nullptr; 416 } 417 418 jmethodID result_class_constructor = env->GetMethodID( 419 result_class, "<init>", 420 "(II[L" TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationResult;)V"); 421 422 jobjectArray results = 423 env->NewObjectArray(annotations.size(), result_class, nullptr); 424 425 for (int i = 0; i < annotations.size(); ++i) { 426 CodepointSpan span_bmp = 427 ConvertIndicesUTF8ToBMP(context_utf8, annotations[i].span); 428 jobject result = env->NewObject( 429 result_class, result_class_constructor, 430 static_cast<jint>(span_bmp.first), static_cast<jint>(span_bmp.second), 431 ClassificationResultsToJObjectArray(env, 432 433 annotations[i].classification)); 434 env->SetObjectArrayElement(results, i, result); 435 env->DeleteLocalRef(result); 436 } 437 env->DeleteLocalRef(result_class); 438 return results; 439 } 440 441 JNI_METHOD(void, TC_CLASS_NAME, nativeClose) 442 (JNIEnv* env, jobject thiz, jlong ptr) { 443 TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr); 444 delete model; 445 } 446 447 JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLanguage) 448 (JNIEnv* env, jobject clazz, jint fd) { 449 TC_LOG(WARNING) << "Using deprecated getLanguage()."; 450 return JNI_METHOD_NAME(TC_CLASS_NAME, nativeGetLocales)(env, clazz, fd); 451 } 452 453 JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLocales) 454 (JNIEnv* env, jobject clazz, jint fd) { 455 const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap( 456 new libtextclassifier2::ScopedMmap(fd)); 457 return GetLocalesFromMmap(env, mmap.get()); 458 } 459 460 JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLocalesFromAssetFileDescriptor) 461 (JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) { 462 const jint fd = libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd); 463 const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap( 464 new libtextclassifier2::ScopedMmap(fd, offset, size)); 465 return GetLocalesFromMmap(env, mmap.get()); 466 } 467 468 JNI_METHOD(jint, TC_CLASS_NAME, nativeGetVersion) 469 (JNIEnv* env, jobject clazz, jint fd) { 470 const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap( 471 new libtextclassifier2::ScopedMmap(fd)); 472 return GetVersionFromMmap(env, mmap.get()); 473 } 474 475 JNI_METHOD(jint, TC_CLASS_NAME, nativeGetVersionFromAssetFileDescriptor) 476 (JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) { 477 const jint fd = libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd); 478 const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap( 479 new libtextclassifier2::ScopedMmap(fd, offset, size)); 480 return GetVersionFromMmap(env, mmap.get()); 481 } 482 483 JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetName) 484 (JNIEnv* env, jobject clazz, jint fd) { 485 const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap( 486 new libtextclassifier2::ScopedMmap(fd)); 487 return GetNameFromMmap(env, mmap.get()); 488 } 489 490 JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetNameFromAssetFileDescriptor) 491 (JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) { 492 const jint fd = libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd); 493 const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap( 494 new libtextclassifier2::ScopedMmap(fd, offset, size)); 495 return GetNameFromMmap(env, mmap.get()); 496 } 497