Home | History | Annotate | Download | only in actions
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 // JNI wrapper for actions.
     18 
     19 #include "actions/actions_jni.h"
     20 
     21 #include <jni.h>
     22 #include <map>
     23 #include <type_traits>
     24 #include <vector>
     25 
     26 #include "actions/actions-suggestions.h"
     27 #include "annotator/annotator.h"
     28 #include "annotator/annotator_jni_common.h"
     29 #include "utils/base/integral_types.h"
     30 #include "utils/intents/intent-generator.h"
     31 #include "utils/intents/jni.h"
     32 #include "utils/java/jni-cache.h"
     33 #include "utils/java/scoped_local_ref.h"
     34 #include "utils/java/string_utils.h"
     35 #include "utils/memory/mmap.h"
     36 
     37 using libtextclassifier3::ActionsSuggestions;
     38 using libtextclassifier3::ActionsSuggestionsResponse;
     39 using libtextclassifier3::ActionSuggestion;
     40 using libtextclassifier3::ActionSuggestionOptions;
     41 using libtextclassifier3::Annotator;
     42 using libtextclassifier3::Conversation;
     43 using libtextclassifier3::IntentGenerator;
     44 using libtextclassifier3::ScopedLocalRef;
     45 using libtextclassifier3::ToStlString;
     46 
     47 // When using the Java's ICU, UniLib needs to be instantiated with a JavaVM
     48 // pointer from JNI. When using a standard ICU the pointer is not needed and the
     49 // objects are instantiated implicitly.
     50 #ifdef TC3_UNILIB_JAVAICU
     51 using libtextclassifier3::UniLib;
     52 #endif
     53 
     54 namespace libtextclassifier3 {
     55 
     56 namespace {
     57 
     58 // Cached state for model inference.
     59 // Keeps a jni cache, intent generator and model instance so that they don't
     60 // have to be recreated for each call.
     61 class ActionsSuggestionsJniContext {
     62  public:
     63   static ActionsSuggestionsJniContext* Create(
     64       const std::shared_ptr<libtextclassifier3::JniCache>& jni_cache,
     65       std::unique_ptr<ActionsSuggestions> model) {
     66     if (jni_cache == nullptr || model == nullptr) {
     67       return nullptr;
     68     }
     69     std::unique_ptr<IntentGenerator> intent_generator =
     70         IntentGenerator::Create(model->model()->android_intent_options(),
     71                                 model->model()->resources(), jni_cache);
     72     std::unique_ptr<RemoteActionTemplatesHandler> template_handler =
     73         libtextclassifier3::RemoteActionTemplatesHandler::Create(jni_cache);
     74 
     75     if (intent_generator == nullptr || template_handler == nullptr) {
     76       return nullptr;
     77     }
     78 
     79     return new ActionsSuggestionsJniContext(jni_cache, std::move(model),
     80                                             std::move(intent_generator),
     81                                             std::move(template_handler));
     82   }
     83 
     84   std::shared_ptr<libtextclassifier3::JniCache> jni_cache() const {
     85     return jni_cache_;
     86   }
     87 
     88   ActionsSuggestions* model() const { return model_.get(); }
     89 
     90   IntentGenerator* intent_generator() const { return intent_generator_.get(); }
     91 
     92   RemoteActionTemplatesHandler* template_handler() const {
     93     return template_handler_.get();
     94   }
     95 
     96  private:
     97   ActionsSuggestionsJniContext(
     98       const std::shared_ptr<libtextclassifier3::JniCache>& jni_cache,
     99       std::unique_ptr<ActionsSuggestions> model,
    100       std::unique_ptr<IntentGenerator> intent_generator,
    101       std::unique_ptr<RemoteActionTemplatesHandler> template_handler)
    102       : jni_cache_(jni_cache),
    103         model_(std::move(model)),
    104         intent_generator_(std::move(intent_generator)),
    105         template_handler_(std::move(template_handler)) {}
    106 
    107   std::shared_ptr<libtextclassifier3::JniCache> jni_cache_;
    108   std::unique_ptr<ActionsSuggestions> model_;
    109   std::unique_ptr<IntentGenerator> intent_generator_;
    110   std::unique_ptr<RemoteActionTemplatesHandler> template_handler_;
    111 };
    112 
    113 ActionSuggestionOptions FromJavaActionSuggestionOptions(JNIEnv* env,
    114                                                         jobject joptions) {
    115   ActionSuggestionOptions options = ActionSuggestionOptions::Default();
    116   return options;
    117 }
    118 
    119 jobjectArray ActionSuggestionsToJObjectArray(
    120     JNIEnv* env, const ActionsSuggestionsJniContext* context,
    121     jobject app_context,
    122     const reflection::Schema* annotations_entity_data_schema,
    123     const std::vector<ActionSuggestion>& action_result,
    124     const Conversation& conversation, const jstring device_locales,
    125     const bool generate_intents) {
    126   const ScopedLocalRef<jclass> result_class(
    127       env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
    128                      "$ActionSuggestion"),
    129       env);
    130   if (!result_class) {
    131     TC3_LOG(ERROR) << "Couldn't find ActionSuggestion class.";
    132     return nullptr;
    133   }
    134 
    135   const jmethodID result_class_constructor = env->GetMethodID(
    136       result_class.get(), "<init>",
    137       "(Ljava/lang/String;Ljava/lang/String;F[L" TC3_PACKAGE_PATH
    138           TC3_NAMED_VARIANT_CLASS_NAME_STR
    139       ";[B[L" TC3_PACKAGE_PATH TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR ";)V");
    140   const jobjectArray results =
    141       env->NewObjectArray(action_result.size(), result_class.get(), nullptr);
    142   for (int i = 0; i < action_result.size(); i++) {
    143     jobject extras = nullptr;
    144 
    145     const reflection::Schema* actions_entity_data_schema =
    146         context->model()->entity_data_schema();
    147     if (actions_entity_data_schema != nullptr &&
    148         !action_result[i].serialized_entity_data.empty()) {
    149       extras = context->template_handler()->EntityDataAsNamedVariantArray(
    150           actions_entity_data_schema, action_result[i].serialized_entity_data);
    151     }
    152 
    153     jbyteArray serialized_entity_data = nullptr;
    154     if (!action_result[i].serialized_entity_data.empty()) {
    155       serialized_entity_data =
    156           env->NewByteArray(action_result[i].serialized_entity_data.size());
    157       env->SetByteArrayRegion(
    158           serialized_entity_data, 0,
    159           action_result[i].serialized_entity_data.size(),
    160           reinterpret_cast<const jbyte*>(
    161               action_result[i].serialized_entity_data.data()));
    162     }
    163 
    164     jobject remote_action_templates_result = nullptr;
    165     if (generate_intents) {
    166       std::vector<RemoteActionTemplate> remote_action_templates;
    167       if (context->intent_generator()->GenerateIntents(
    168               device_locales, action_result[i], conversation, app_context,
    169               actions_entity_data_schema, annotations_entity_data_schema,
    170               &remote_action_templates)) {
    171         remote_action_templates_result =
    172             context->template_handler()->RemoteActionTemplatesToJObjectArray(
    173                 remote_action_templates);
    174       }
    175     }
    176 
    177     ScopedLocalRef<jstring> reply = context->jni_cache()->ConvertToJavaString(
    178         action_result[i].response_text);
    179 
    180     ScopedLocalRef<jobject> result(env->NewObject(
    181         result_class.get(), result_class_constructor, reply.get(),
    182         env->NewStringUTF(action_result[i].type.c_str()),
    183         static_cast<jfloat>(action_result[i].score), extras,
    184         serialized_entity_data, remote_action_templates_result));
    185     env->SetObjectArrayElement(results, i, result.get());
    186   }
    187   return results;
    188 }
    189 
    190 ConversationMessage FromJavaConversationMessage(JNIEnv* env, jobject jmessage) {
    191   if (!jmessage) {
    192     return {};
    193   }
    194 
    195   const ScopedLocalRef<jclass> message_class(
    196       env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
    197                      "$ConversationMessage"),
    198       env);
    199   const std::pair<bool, jobject> status_or_text = CallJniMethod0<jobject>(
    200       env, jmessage, message_class.get(), &JNIEnv::CallObjectMethod, "getText",
    201       "Ljava/lang/String;");
    202   const std::pair<bool, int32> status_or_user_id =
    203       CallJniMethod0<int32>(env, jmessage, message_class.get(),
    204                             &JNIEnv::CallIntMethod, "getUserId", "I");
    205   const std::pair<bool, int64> status_or_reference_time = CallJniMethod0<int64>(
    206       env, jmessage, message_class.get(), &JNIEnv::CallLongMethod,
    207       "getReferenceTimeMsUtc", "J");
    208   const std::pair<bool, jobject> status_or_reference_timezone =
    209       CallJniMethod0<jobject>(env, jmessage, message_class.get(),
    210                               &JNIEnv::CallObjectMethod, "getReferenceTimezone",
    211                               "Ljava/lang/String;");
    212   const std::pair<bool, jobject> status_or_detected_text_language_tags =
    213       CallJniMethod0<jobject>(
    214           env, jmessage, message_class.get(), &JNIEnv::CallObjectMethod,
    215           "getDetectedTextLanguageTags", "Ljava/lang/String;");
    216   if (!status_or_text.first || !status_or_user_id.first ||
    217       !status_or_detected_text_language_tags.first ||
    218       !status_or_reference_time.first || !status_or_reference_timezone.first) {
    219     return {};
    220   }
    221 
    222   ConversationMessage message;
    223   message.text = ToStlString(env, static_cast<jstring>(status_or_text.second));
    224   message.user_id = status_or_user_id.second;
    225   message.reference_time_ms_utc = status_or_reference_time.second;
    226   message.reference_timezone = ToStlString(
    227       env, static_cast<jstring>(status_or_reference_timezone.second));
    228   message.detected_text_language_tags = ToStlString(
    229       env, static_cast<jstring>(status_or_detected_text_language_tags.second));
    230   return message;
    231 }
    232 
    233 Conversation FromJavaConversation(JNIEnv* env, jobject jconversation) {
    234   if (!jconversation) {
    235     return {};
    236   }
    237 
    238   const ScopedLocalRef<jclass> conversation_class(
    239       env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
    240                      "$Conversation"),
    241       env);
    242 
    243   const std::pair<bool, jobject> status_or_messages = CallJniMethod0<jobject>(
    244       env, jconversation, conversation_class.get(), &JNIEnv::CallObjectMethod,
    245       "getConversationMessages",
    246       "[L" TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR "$ConversationMessage;");
    247 
    248   if (!status_or_messages.first) {
    249     return {};
    250   }
    251 
    252   const jobjectArray jmessages =
    253       reinterpret_cast<jobjectArray>(status_or_messages.second);
    254 
    255   const int size = env->GetArrayLength(jmessages);
    256 
    257   std::vector<ConversationMessage> messages;
    258   for (int i = 0; i < size; i++) {
    259     jobject jmessage = env->GetObjectArrayElement(jmessages, i);
    260     ConversationMessage message = FromJavaConversationMessage(env, jmessage);
    261     messages.push_back(message);
    262   }
    263   Conversation conversation;
    264   conversation.messages = messages;
    265   return conversation;
    266 }
    267 
    268 jstring GetLocalesFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
    269   if (!mmap->handle().ok()) {
    270     return env->NewStringUTF("");
    271   }
    272   const ActionsModel* model = libtextclassifier3::ViewActionsModel(
    273       mmap->handle().start(), mmap->handle().num_bytes());
    274   if (!model || !model->locales()) {
    275     return env->NewStringUTF("");
    276   }
    277   return env->NewStringUTF(model->locales()->c_str());
    278 }
    279 
    280 jint GetVersionFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
    281   if (!mmap->handle().ok()) {
    282     return 0;
    283   }
    284   const ActionsModel* model = libtextclassifier3::ViewActionsModel(
    285       mmap->handle().start(), mmap->handle().num_bytes());
    286   if (!model) {
    287     return 0;
    288   }
    289   return model->version();
    290 }
    291 
    292 jstring GetNameFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
    293   if (!mmap->handle().ok()) {
    294     return env->NewStringUTF("");
    295   }
    296   const ActionsModel* model = libtextclassifier3::ViewActionsModel(
    297       mmap->handle().start(), mmap->handle().num_bytes());
    298   if (!model || !model->name()) {
    299     return env->NewStringUTF("");
    300   }
    301   return env->NewStringUTF(model->name()->c_str());
    302 }
    303 }  // namespace
    304 }  // namespace libtextclassifier3
    305 
    306 using libtextclassifier3::ActionsSuggestionsJniContext;
    307 using libtextclassifier3::ActionSuggestionsToJObjectArray;
    308 using libtextclassifier3::FromJavaActionSuggestionOptions;
    309 using libtextclassifier3::FromJavaConversation;
    310 
    311 TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModel)
    312 (JNIEnv* env, jobject thiz, jint fd, jbyteArray serialized_preconditions) {
    313   std::shared_ptr<libtextclassifier3::JniCache> jni_cache =
    314       libtextclassifier3::JniCache::Create(env);
    315   std::string preconditions;
    316   if (serialized_preconditions != nullptr &&
    317       !libtextclassifier3::JByteArrayToString(env, serialized_preconditions,
    318                                               &preconditions)) {
    319     TC3_LOG(ERROR) << "Could not convert serialized preconditions.";
    320     return 0;
    321   }
    322 #ifdef TC3_UNILIB_JAVAICU
    323   return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
    324       jni_cache,
    325       ActionsSuggestions::FromFileDescriptor(
    326           fd, std::unique_ptr<UniLib>(new UniLib(jni_cache)), preconditions)));
    327 #else
    328   return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
    329       jni_cache, ActionsSuggestions::FromFileDescriptor(fd, /*unilib=*/nullptr,
    330                                                         preconditions)));
    331 #endif  // TC3_UNILIB_JAVAICU
    332 }
    333 
    334 TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModelFromPath)
    335 (JNIEnv* env, jobject thiz, jstring path, jbyteArray serialized_preconditions) {
    336   std::shared_ptr<libtextclassifier3::JniCache> jni_cache =
    337       libtextclassifier3::JniCache::Create(env);
    338   const std::string path_str = ToStlString(env, path);
    339   std::string preconditions;
    340   if (serialized_preconditions != nullptr &&
    341       !libtextclassifier3::JByteArrayToString(env, serialized_preconditions,
    342                                               &preconditions)) {
    343     TC3_LOG(ERROR) << "Could not convert serialized preconditions.";
    344     return 0;
    345   }
    346 #ifdef TC3_UNILIB_JAVAICU
    347   return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
    348       jni_cache, ActionsSuggestions::FromPath(
    349                      path_str, std::unique_ptr<UniLib>(new UniLib(jni_cache)),
    350                      preconditions)));
    351 #else
    352   return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create(
    353       jni_cache, ActionsSuggestions::FromPath(path_str, /*unilib=*/nullptr,
    354                                               preconditions)));
    355 #endif  // TC3_UNILIB_JAVAICU
    356 }
    357 
    358 TC3_JNI_METHOD(jobjectArray, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions)
    359 (JNIEnv* env, jobject clazz, jlong ptr, jobject jconversation, jobject joptions,
    360  jlong annotatorPtr, jobject app_context, jstring device_locales,
    361  jboolean generate_intents) {
    362   if (!ptr) {
    363     return nullptr;
    364   }
    365   const Conversation conversation = FromJavaConversation(env, jconversation);
    366   const ActionSuggestionOptions options =
    367       FromJavaActionSuggestionOptions(env, joptions);
    368   const ActionsSuggestionsJniContext* context =
    369       reinterpret_cast<ActionsSuggestionsJniContext*>(ptr);
    370   const Annotator* annotator = reinterpret_cast<Annotator*>(annotatorPtr);
    371 
    372   const ActionsSuggestionsResponse response =
    373       context->model()->SuggestActions(conversation, annotator, options);
    374 
    375   const reflection::Schema* anntotations_entity_data_schema =
    376       annotator ? annotator->entity_data_schema() : nullptr;
    377   return ActionSuggestionsToJObjectArray(
    378       env, context, app_context, anntotations_entity_data_schema,
    379       response.actions, conversation, device_locales, generate_intents);
    380 }
    381 
    382 TC3_JNI_METHOD(void, TC3_ACTIONS_CLASS_NAME, nativeCloseActionsModel)
    383 (JNIEnv* env, jobject clazz, jlong model_ptr) {
    384   const ActionsSuggestionsJniContext* context =
    385       reinterpret_cast<ActionsSuggestionsJniContext*>(model_ptr);
    386   delete context;
    387 }
    388 
    389 TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetLocales)
    390 (JNIEnv* env, jobject clazz, jint fd) {
    391   const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
    392       new libtextclassifier3::ScopedMmap(fd));
    393   return libtextclassifier3::GetLocalesFromMmap(env, mmap.get());
    394 }
    395 
    396 TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetName)
    397 (JNIEnv* env, jobject clazz, jint fd) {
    398   const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
    399       new libtextclassifier3::ScopedMmap(fd));
    400   return libtextclassifier3::GetNameFromMmap(env, mmap.get());
    401 }
    402 
    403 TC3_JNI_METHOD(jint, TC3_ACTIONS_CLASS_NAME, nativeGetVersion)
    404 (JNIEnv* env, jobject clazz, jint fd) {
    405   const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
    406       new libtextclassifier3::ScopedMmap(fd));
    407   return libtextclassifier3::GetVersionFromMmap(env, mmap.get());
    408 }
    409