Home | History | Annotate | Download | only in libtextclassifier
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 // JNI wrapper for the TextClassifier.
     18 
     19 #include "textclassifier_jni.h"
     20 
     21 #include <jni.h>
     22 #include <type_traits>
     23 #include <vector>
     24 
     25 #include "text-classifier.h"
     26 #include "util/base/integral_types.h"
     27 #include "util/java/scoped_local_ref.h"
     28 #include "util/java/string_utils.h"
     29 #include "util/memory/mmap.h"
     30 #include "util/utf8/unilib.h"
     31 
     32 using libtextclassifier2::AnnotatedSpan;
     33 using libtextclassifier2::AnnotationOptions;
     34 using libtextclassifier2::ClassificationOptions;
     35 using libtextclassifier2::ClassificationResult;
     36 using libtextclassifier2::CodepointSpan;
     37 using libtextclassifier2::JStringToUtf8String;
     38 using libtextclassifier2::Model;
     39 using libtextclassifier2::ScopedLocalRef;
     40 using libtextclassifier2::SelectionOptions;
     41 using libtextclassifier2::TextClassifier;
     42 #ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU
     43 using libtextclassifier2::UniLib;
     44 #endif
     45 
     46 namespace libtextclassifier2 {
     47 
     48 using libtextclassifier2::CodepointSpan;
     49 
     50 namespace {
     51 
     52 std::string ToStlString(JNIEnv* env, const jstring& str) {
     53   std::string result;
     54   JStringToUtf8String(env, str, &result);
     55   return result;
     56 }
     57 
     58 jobjectArray ClassificationResultsToJObjectArray(
     59     JNIEnv* env,
     60     const std::vector<ClassificationResult>& classification_result) {
     61   const ScopedLocalRef<jclass> result_class(
     62       env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationResult"),
     63       env);
     64   if (!result_class) {
     65     TC_LOG(ERROR) << "Couldn't find ClassificationResult class.";
     66     return nullptr;
     67   }
     68   const ScopedLocalRef<jclass> datetime_parse_class(
     69       env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$DatetimeResult"), env);
     70   if (!datetime_parse_class) {
     71     TC_LOG(ERROR) << "Couldn't find DatetimeResult class.";
     72     return nullptr;
     73   }
     74 
     75   const jmethodID result_class_constructor =
     76       env->GetMethodID(result_class.get(), "<init>",
     77                        "(Ljava/lang/String;FL" TC_PACKAGE_PATH TC_CLASS_NAME_STR
     78                        "$DatetimeResult;)V");
     79   const jmethodID datetime_parse_class_constructor =
     80       env->GetMethodID(datetime_parse_class.get(), "<init>", "(JI)V");
     81 
     82   const jobjectArray results = env->NewObjectArray(classification_result.size(),
     83                                                    result_class.get(), nullptr);
     84   for (int i = 0; i < classification_result.size(); i++) {
     85     jstring row_string =
     86         env->NewStringUTF(classification_result[i].collection.c_str());
     87     jobject row_datetime_parse = nullptr;
     88     if (classification_result[i].datetime_parse_result.IsSet()) {
     89       row_datetime_parse = env->NewObject(
     90           datetime_parse_class.get(), datetime_parse_class_constructor,
     91           classification_result[i].datetime_parse_result.time_ms_utc,
     92           classification_result[i].datetime_parse_result.granularity);
     93     }
     94     jobject result =
     95         env->NewObject(result_class.get(), result_class_constructor, row_string,
     96                        static_cast<jfloat>(classification_result[i].score),
     97                        row_datetime_parse);
     98     env->SetObjectArrayElement(results, i, result);
     99     env->DeleteLocalRef(result);
    100   }
    101   return results;
    102 }
    103 
    104 template <typename T, typename F>
    105 std::pair<bool, T> CallJniMethod0(JNIEnv* env, jobject object,
    106                                   jclass class_object, F function,
    107                                   const std::string& method_name,
    108                                   const std::string& return_java_type) {
    109   const jmethodID method = env->GetMethodID(class_object, method_name.c_str(),
    110                                             ("()" + return_java_type).c_str());
    111   if (!method) {
    112     return std::make_pair(false, T());
    113   }
    114   return std::make_pair(true, (env->*function)(object, method));
    115 }
    116 
    117 SelectionOptions FromJavaSelectionOptions(JNIEnv* env, jobject joptions) {
    118   if (!joptions) {
    119     return {};
    120   }
    121 
    122   const ScopedLocalRef<jclass> options_class(
    123       env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$SelectionOptions"),
    124       env);
    125   const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>(
    126       env, joptions, options_class.get(), &JNIEnv::CallObjectMethod,
    127       "getLocales", "Ljava/lang/String;");
    128   if (!status_or_locales.first) {
    129     return {};
    130   }
    131 
    132   SelectionOptions options;
    133   options.locales =
    134       ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second));
    135 
    136   return options;
    137 }
    138 
    139 template <typename T>
    140 T FromJavaOptionsInternal(JNIEnv* env, jobject joptions,
    141                           const std::string& class_name) {
    142   if (!joptions) {
    143     return {};
    144   }
    145 
    146   const ScopedLocalRef<jclass> options_class(env->FindClass(class_name.c_str()),
    147                                              env);
    148   if (!options_class) {
    149     return {};
    150   }
    151 
    152   const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>(
    153       env, joptions, options_class.get(), &JNIEnv::CallObjectMethod,
    154       "getLocale", "Ljava/lang/String;");
    155   const std::pair<bool, jobject> status_or_reference_timezone =
    156       CallJniMethod0<jobject>(env, joptions, options_class.get(),
    157                               &JNIEnv::CallObjectMethod, "getReferenceTimezone",
    158                               "Ljava/lang/String;");
    159   const std::pair<bool, int64> status_or_reference_time_ms_utc =
    160       CallJniMethod0<int64>(env, joptions, options_class.get(),
    161                             &JNIEnv::CallLongMethod, "getReferenceTimeMsUtc",
    162                             "J");
    163 
    164   if (!status_or_locales.first || !status_or_reference_timezone.first ||
    165       !status_or_reference_time_ms_utc.first) {
    166     return {};
    167   }
    168 
    169   T options;
    170   options.locales =
    171       ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second));
    172   options.reference_timezone = ToStlString(
    173       env, reinterpret_cast<jstring>(status_or_reference_timezone.second));
    174   options.reference_time_ms_utc = status_or_reference_time_ms_utc.second;
    175   return options;
    176 }
    177 
    178 ClassificationOptions FromJavaClassificationOptions(JNIEnv* env,
    179                                                     jobject joptions) {
    180   return FromJavaOptionsInternal<ClassificationOptions>(
    181       env, joptions,
    182       TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationOptions");
    183 }
    184 
    185 AnnotationOptions FromJavaAnnotationOptions(JNIEnv* env, jobject joptions) {
    186   return FromJavaOptionsInternal<AnnotationOptions>(
    187       env, joptions, TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotationOptions");
    188 }
    189 
    190 CodepointSpan ConvertIndicesBMPUTF8(const std::string& utf8_str,
    191                                     CodepointSpan orig_indices,
    192                                     bool from_utf8) {
    193   const libtextclassifier2::UnicodeText unicode_str =
    194       libtextclassifier2::UTF8ToUnicodeText(utf8_str, /*do_copy=*/false);
    195 
    196   int unicode_index = 0;
    197   int bmp_index = 0;
    198 
    199   const int* source_index;
    200   const int* target_index;
    201   if (from_utf8) {
    202     source_index = &unicode_index;
    203     target_index = &bmp_index;
    204   } else {
    205     source_index = &bmp_index;
    206     target_index = &unicode_index;
    207   }
    208 
    209   CodepointSpan result{-1, -1};
    210   std::function<void()> assign_indices_fn = [&result, &orig_indices,
    211                                              &source_index, &target_index]() {
    212     if (orig_indices.first == *source_index) {
    213       result.first = *target_index;
    214     }
    215 
    216     if (orig_indices.second == *source_index) {
    217       result.second = *target_index;
    218     }
    219   };
    220 
    221   for (auto it = unicode_str.begin(); it != unicode_str.end();
    222        ++it, ++unicode_index, ++bmp_index) {
    223     assign_indices_fn();
    224 
    225     // There is 1 extra character in the input for each UTF8 character > 0xFFFF.
    226     if (*it > 0xFFFF) {
    227       ++bmp_index;
    228     }
    229   }
    230   assign_indices_fn();
    231 
    232   return result;
    233 }
    234 
    235 }  // namespace
    236 
    237 CodepointSpan ConvertIndicesBMPToUTF8(const std::string& utf8_str,
    238                                       CodepointSpan bmp_indices) {
    239   return ConvertIndicesBMPUTF8(utf8_str, bmp_indices, /*from_utf8=*/false);
    240 }
    241 
    242 CodepointSpan ConvertIndicesUTF8ToBMP(const std::string& utf8_str,
    243                                       CodepointSpan utf8_indices) {
    244   return ConvertIndicesBMPUTF8(utf8_str, utf8_indices, /*from_utf8=*/true);
    245 }
    246 
    247 jint GetFdFromAssetFileDescriptor(JNIEnv* env, jobject afd) {
    248   // Get system-level file descriptor from AssetFileDescriptor.
    249   ScopedLocalRef<jclass> afd_class(
    250       env->FindClass("android/content/res/AssetFileDescriptor"), env);
    251   if (afd_class == nullptr) {
    252     TC_LOG(ERROR) << "Couldn't find AssetFileDescriptor.";
    253     return reinterpret_cast<jlong>(nullptr);
    254   }
    255   jmethodID afd_class_getFileDescriptor = env->GetMethodID(
    256       afd_class.get(), "getFileDescriptor", "()Ljava/io/FileDescriptor;");
    257   if (afd_class_getFileDescriptor == nullptr) {
    258     TC_LOG(ERROR) << "Couldn't find getFileDescriptor.";
    259     return reinterpret_cast<jlong>(nullptr);
    260   }
    261 
    262   ScopedLocalRef<jclass> fd_class(env->FindClass("java/io/FileDescriptor"),
    263                                   env);
    264   if (fd_class == nullptr) {
    265     TC_LOG(ERROR) << "Couldn't find FileDescriptor.";
    266     return reinterpret_cast<jlong>(nullptr);
    267   }
    268   jfieldID fd_class_descriptor =
    269       env->GetFieldID(fd_class.get(), "descriptor", "I");
    270   if (fd_class_descriptor == nullptr) {
    271     TC_LOG(ERROR) << "Couldn't find descriptor.";
    272     return reinterpret_cast<jlong>(nullptr);
    273   }
    274 
    275   jobject bundle_jfd = env->CallObjectMethod(afd, afd_class_getFileDescriptor);
    276   return env->GetIntField(bundle_jfd, fd_class_descriptor);
    277 }
    278 
    279 jstring GetLocalesFromMmap(JNIEnv* env, libtextclassifier2::ScopedMmap* mmap) {
    280   if (!mmap->handle().ok()) {
    281     return env->NewStringUTF("");
    282   }
    283   const Model* model = libtextclassifier2::ViewModel(
    284       mmap->handle().start(), mmap->handle().num_bytes());
    285   if (!model || !model->locales()) {
    286     return env->NewStringUTF("");
    287   }
    288   return env->NewStringUTF(model->locales()->c_str());
    289 }
    290 
    291 jint GetVersionFromMmap(JNIEnv* env, libtextclassifier2::ScopedMmap* mmap) {
    292   if (!mmap->handle().ok()) {
    293     return 0;
    294   }
    295   const Model* model = libtextclassifier2::ViewModel(
    296       mmap->handle().start(), mmap->handle().num_bytes());
    297   if (!model) {
    298     return 0;
    299   }
    300   return model->version();
    301 }
    302 
    303 jstring GetNameFromMmap(JNIEnv* env, libtextclassifier2::ScopedMmap* mmap) {
    304   if (!mmap->handle().ok()) {
    305     return env->NewStringUTF("");
    306   }
    307   const Model* model = libtextclassifier2::ViewModel(
    308       mmap->handle().start(), mmap->handle().num_bytes());
    309   if (!model || !model->name()) {
    310     return env->NewStringUTF("");
    311   }
    312   return env->NewStringUTF(model->name()->c_str());
    313 }
    314 
    315 }  // namespace libtextclassifier2
    316 
    317 using libtextclassifier2::ClassificationResultsToJObjectArray;
    318 using libtextclassifier2::ConvertIndicesBMPToUTF8;
    319 using libtextclassifier2::ConvertIndicesUTF8ToBMP;
    320 using libtextclassifier2::FromJavaAnnotationOptions;
    321 using libtextclassifier2::FromJavaClassificationOptions;
    322 using libtextclassifier2::FromJavaSelectionOptions;
    323 using libtextclassifier2::ToStlString;
    324 
    325 JNI_METHOD(jlong, TC_CLASS_NAME, nativeNew)
    326 (JNIEnv* env, jobject thiz, jint fd) {
    327 #ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU
    328   return reinterpret_cast<jlong>(
    329       TextClassifier::FromFileDescriptor(fd).release(), new UniLib(env));
    330 #else
    331   return reinterpret_cast<jlong>(
    332       TextClassifier::FromFileDescriptor(fd).release());
    333 #endif
    334 }
    335 
    336 JNI_METHOD(jlong, TC_CLASS_NAME, nativeNewFromPath)
    337 (JNIEnv* env, jobject thiz, jstring path) {
    338   const std::string path_str = ToStlString(env, path);
    339 #ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU
    340   return reinterpret_cast<jlong>(
    341       TextClassifier::FromPath(path_str, new UniLib(env)).release());
    342 #else
    343   return reinterpret_cast<jlong>(TextClassifier::FromPath(path_str).release());
    344 #endif
    345 }
    346 
    347 JNI_METHOD(jlong, TC_CLASS_NAME, nativeNewFromAssetFileDescriptor)
    348 (JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
    349   const jint fd = libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd);
    350 #ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU
    351   return reinterpret_cast<jlong>(
    352       TextClassifier::FromFileDescriptor(fd, offset, size, new UniLib(env))
    353           .release());
    354 #else
    355   return reinterpret_cast<jlong>(
    356       TextClassifier::FromFileDescriptor(fd, offset, size).release());
    357 #endif
    358 }
    359 
    360 JNI_METHOD(jintArray, TC_CLASS_NAME, nativeSuggestSelection)
    361 (JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
    362  jint selection_end, jobject options) {
    363   if (!ptr) {
    364     return nullptr;
    365   }
    366 
    367   TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr);
    368 
    369   const std::string context_utf8 = ToStlString(env, context);
    370   CodepointSpan input_indices =
    371       ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
    372   CodepointSpan selection = model->SuggestSelection(
    373       context_utf8, input_indices, FromJavaSelectionOptions(env, options));
    374   selection = ConvertIndicesUTF8ToBMP(context_utf8, selection);
    375 
    376   jintArray result = env->NewIntArray(2);
    377   env->SetIntArrayRegion(result, 0, 1, &(std::get<0>(selection)));
    378   env->SetIntArrayRegion(result, 1, 1, &(std::get<1>(selection)));
    379   return result;
    380 }
    381 
    382 JNI_METHOD(jobjectArray, TC_CLASS_NAME, nativeClassifyText)
    383 (JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin,
    384  jint selection_end, jobject options) {
    385   if (!ptr) {
    386     return nullptr;
    387   }
    388   TextClassifier* ff_model = reinterpret_cast<TextClassifier*>(ptr);
    389 
    390   const std::string context_utf8 = ToStlString(env, context);
    391   const CodepointSpan input_indices =
    392       ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
    393   const std::vector<ClassificationResult> classification_result =
    394       ff_model->ClassifyText(context_utf8, input_indices,
    395                              FromJavaClassificationOptions(env, options));
    396 
    397   return ClassificationResultsToJObjectArray(env, classification_result);
    398 }
    399 
    400 JNI_METHOD(jobjectArray, TC_CLASS_NAME, nativeAnnotate)
    401 (JNIEnv* env, jobject thiz, jlong ptr, jstring context, jobject options) {
    402   if (!ptr) {
    403     return nullptr;
    404   }
    405   TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr);
    406   std::string context_utf8 = ToStlString(env, context);
    407   std::vector<AnnotatedSpan> annotations =
    408       model->Annotate(context_utf8, FromJavaAnnotationOptions(env, options));
    409 
    410   jclass result_class =
    411       env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotatedSpan");
    412   if (!result_class) {
    413     TC_LOG(ERROR) << "Couldn't find result class: "
    414                   << TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotatedSpan";
    415     return nullptr;
    416   }
    417 
    418   jmethodID result_class_constructor = env->GetMethodID(
    419       result_class, "<init>",
    420       "(II[L" TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationResult;)V");
    421 
    422   jobjectArray results =
    423       env->NewObjectArray(annotations.size(), result_class, nullptr);
    424 
    425   for (int i = 0; i < annotations.size(); ++i) {
    426     CodepointSpan span_bmp =
    427         ConvertIndicesUTF8ToBMP(context_utf8, annotations[i].span);
    428     jobject result = env->NewObject(
    429         result_class, result_class_constructor,
    430         static_cast<jint>(span_bmp.first), static_cast<jint>(span_bmp.second),
    431         ClassificationResultsToJObjectArray(env,
    432 
    433                                             annotations[i].classification));
    434     env->SetObjectArrayElement(results, i, result);
    435     env->DeleteLocalRef(result);
    436   }
    437   env->DeleteLocalRef(result_class);
    438   return results;
    439 }
    440 
    441 JNI_METHOD(void, TC_CLASS_NAME, nativeClose)
    442 (JNIEnv* env, jobject thiz, jlong ptr) {
    443   TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr);
    444   delete model;
    445 }
    446 
    447 JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLanguage)
    448 (JNIEnv* env, jobject clazz, jint fd) {
    449   TC_LOG(WARNING) << "Using deprecated getLanguage().";
    450   return JNI_METHOD_NAME(TC_CLASS_NAME, nativeGetLocales)(env, clazz, fd);
    451 }
    452 
    453 JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLocales)
    454 (JNIEnv* env, jobject clazz, jint fd) {
    455   const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap(
    456       new libtextclassifier2::ScopedMmap(fd));
    457   return GetLocalesFromMmap(env, mmap.get());
    458 }
    459 
    460 JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLocalesFromAssetFileDescriptor)
    461 (JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
    462   const jint fd = libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd);
    463   const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap(
    464       new libtextclassifier2::ScopedMmap(fd, offset, size));
    465   return GetLocalesFromMmap(env, mmap.get());
    466 }
    467 
    468 JNI_METHOD(jint, TC_CLASS_NAME, nativeGetVersion)
    469 (JNIEnv* env, jobject clazz, jint fd) {
    470   const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap(
    471       new libtextclassifier2::ScopedMmap(fd));
    472   return GetVersionFromMmap(env, mmap.get());
    473 }
    474 
    475 JNI_METHOD(jint, TC_CLASS_NAME, nativeGetVersionFromAssetFileDescriptor)
    476 (JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
    477   const jint fd = libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd);
    478   const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap(
    479       new libtextclassifier2::ScopedMmap(fd, offset, size));
    480   return GetVersionFromMmap(env, mmap.get());
    481 }
    482 
    483 JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetName)
    484 (JNIEnv* env, jobject clazz, jint fd) {
    485   const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap(
    486       new libtextclassifier2::ScopedMmap(fd));
    487   return GetNameFromMmap(env, mmap.get());
    488 }
    489 
    490 JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetNameFromAssetFileDescriptor)
    491 (JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) {
    492   const jint fd = libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd);
    493   const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap(
    494       new libtextclassifier2::ScopedMmap(fd, offset, size));
    495   return GetNameFromMmap(env, mmap.get());
    496 }
    497