1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 // JNI wrapper for actions. 18 19 #include "actions/actions_jni.h" 20 21 #include <jni.h> 22 #include <map> 23 #include <type_traits> 24 #include <vector> 25 26 #include "actions/actions-suggestions.h" 27 #include "annotator/annotator.h" 28 #include "annotator/annotator_jni_common.h" 29 #include "utils/base/integral_types.h" 30 #include "utils/intents/intent-generator.h" 31 #include "utils/intents/jni.h" 32 #include "utils/java/jni-cache.h" 33 #include "utils/java/scoped_local_ref.h" 34 #include "utils/java/string_utils.h" 35 #include "utils/memory/mmap.h" 36 37 using libtextclassifier3::ActionsSuggestions; 38 using libtextclassifier3::ActionsSuggestionsResponse; 39 using libtextclassifier3::ActionSuggestion; 40 using libtextclassifier3::ActionSuggestionOptions; 41 using libtextclassifier3::Annotator; 42 using libtextclassifier3::Conversation; 43 using libtextclassifier3::IntentGenerator; 44 using libtextclassifier3::ScopedLocalRef; 45 using libtextclassifier3::ToStlString; 46 47 // When using the Java's ICU, UniLib needs to be instantiated with a JavaVM 48 // pointer from JNI. When using a standard ICU the pointer is not needed and the 49 // objects are instantiated implicitly. 50 #ifdef TC3_UNILIB_JAVAICU 51 using libtextclassifier3::UniLib; 52 #endif 53 54 namespace libtextclassifier3 { 55 56 namespace { 57 58 // Cached state for model inference. 59 // Keeps a jni cache, intent generator and model instance so that they don't 60 // have to be recreated for each call. 61 class ActionsSuggestionsJniContext { 62 public: 63 static ActionsSuggestionsJniContext* Create( 64 const std::shared_ptr<libtextclassifier3::JniCache>& jni_cache, 65 std::unique_ptr<ActionsSuggestions> model) { 66 if (jni_cache == nullptr || model == nullptr) { 67 return nullptr; 68 } 69 std::unique_ptr<IntentGenerator> intent_generator = 70 IntentGenerator::Create(model->model()->android_intent_options(), 71 model->model()->resources(), jni_cache); 72 std::unique_ptr<RemoteActionTemplatesHandler> template_handler = 73 libtextclassifier3::RemoteActionTemplatesHandler::Create(jni_cache); 74 75 if (intent_generator == nullptr || template_handler == nullptr) { 76 return nullptr; 77 } 78 79 return new ActionsSuggestionsJniContext(jni_cache, std::move(model), 80 std::move(intent_generator), 81 std::move(template_handler)); 82 } 83 84 std::shared_ptr<libtextclassifier3::JniCache> jni_cache() const { 85 return jni_cache_; 86 } 87 88 ActionsSuggestions* model() const { return model_.get(); } 89 90 IntentGenerator* intent_generator() const { return intent_generator_.get(); } 91 92 RemoteActionTemplatesHandler* template_handler() const { 93 return template_handler_.get(); 94 } 95 96 private: 97 ActionsSuggestionsJniContext( 98 const std::shared_ptr<libtextclassifier3::JniCache>& jni_cache, 99 std::unique_ptr<ActionsSuggestions> model, 100 std::unique_ptr<IntentGenerator> intent_generator, 101 std::unique_ptr<RemoteActionTemplatesHandler> template_handler) 102 : jni_cache_(jni_cache), 103 model_(std::move(model)), 104 intent_generator_(std::move(intent_generator)), 105 template_handler_(std::move(template_handler)) {} 106 107 std::shared_ptr<libtextclassifier3::JniCache> jni_cache_; 108 std::unique_ptr<ActionsSuggestions> model_; 109 std::unique_ptr<IntentGenerator> intent_generator_; 110 std::unique_ptr<RemoteActionTemplatesHandler> template_handler_; 111 }; 112 113 ActionSuggestionOptions FromJavaActionSuggestionOptions(JNIEnv* env, 114 jobject joptions) { 115 ActionSuggestionOptions options = ActionSuggestionOptions::Default(); 116 return options; 117 } 118 119 jobjectArray ActionSuggestionsToJObjectArray( 120 JNIEnv* env, const ActionsSuggestionsJniContext* context, 121 jobject app_context, 122 const reflection::Schema* annotations_entity_data_schema, 123 const std::vector<ActionSuggestion>& action_result, 124 const Conversation& conversation, const jstring device_locales, 125 const bool generate_intents) { 126 const ScopedLocalRef<jclass> result_class( 127 env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR 128 "$ActionSuggestion"), 129 env); 130 if (!result_class) { 131 TC3_LOG(ERROR) << "Couldn't find ActionSuggestion class."; 132 return nullptr; 133 } 134 135 const jmethodID result_class_constructor = env->GetMethodID( 136 result_class.get(), "<init>", 137 "(Ljava/lang/String;Ljava/lang/String;F[L" TC3_PACKAGE_PATH 138 TC3_NAMED_VARIANT_CLASS_NAME_STR 139 ";[B[L" TC3_PACKAGE_PATH TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR ";)V"); 140 const jobjectArray results = 141 env->NewObjectArray(action_result.size(), result_class.get(), nullptr); 142 for (int i = 0; i < action_result.size(); i++) { 143 jobject extras = nullptr; 144 145 const reflection::Schema* actions_entity_data_schema = 146 context->model()->entity_data_schema(); 147 if (actions_entity_data_schema != nullptr && 148 !action_result[i].serialized_entity_data.empty()) { 149 extras = context->template_handler()->EntityDataAsNamedVariantArray( 150 actions_entity_data_schema, action_result[i].serialized_entity_data); 151 } 152 153 jbyteArray serialized_entity_data = nullptr; 154 if (!action_result[i].serialized_entity_data.empty()) { 155 serialized_entity_data = 156 env->NewByteArray(action_result[i].serialized_entity_data.size()); 157 env->SetByteArrayRegion( 158 serialized_entity_data, 0, 159 action_result[i].serialized_entity_data.size(), 160 reinterpret_cast<const jbyte*>( 161 action_result[i].serialized_entity_data.data())); 162 } 163 164 jobject remote_action_templates_result = nullptr; 165 if (generate_intents) { 166 std::vector<RemoteActionTemplate> remote_action_templates; 167 if (context->intent_generator()->GenerateIntents( 168 device_locales, action_result[i], conversation, app_context, 169 actions_entity_data_schema, annotations_entity_data_schema, 170 &remote_action_templates)) { 171 remote_action_templates_result = 172 context->template_handler()->RemoteActionTemplatesToJObjectArray( 173 remote_action_templates); 174 } 175 } 176 177 ScopedLocalRef<jstring> reply = context->jni_cache()->ConvertToJavaString( 178 action_result[i].response_text); 179 180 ScopedLocalRef<jobject> result(env->NewObject( 181 result_class.get(), result_class_constructor, reply.get(), 182 env->NewStringUTF(action_result[i].type.c_str()), 183 static_cast<jfloat>(action_result[i].score), extras, 184 serialized_entity_data, remote_action_templates_result)); 185 env->SetObjectArrayElement(results, i, result.get()); 186 } 187 return results; 188 } 189 190 ConversationMessage FromJavaConversationMessage(JNIEnv* env, jobject jmessage) { 191 if (!jmessage) { 192 return {}; 193 } 194 195 const ScopedLocalRef<jclass> message_class( 196 env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR 197 "$ConversationMessage"), 198 env); 199 const std::pair<bool, jobject> status_or_text = CallJniMethod0<jobject>( 200 env, jmessage, message_class.get(), &JNIEnv::CallObjectMethod, "getText", 201 "Ljava/lang/String;"); 202 const std::pair<bool, int32> status_or_user_id = 203 CallJniMethod0<int32>(env, jmessage, message_class.get(), 204 &JNIEnv::CallIntMethod, "getUserId", "I"); 205 const std::pair<bool, int64> status_or_reference_time = CallJniMethod0<int64>( 206 env, jmessage, message_class.get(), &JNIEnv::CallLongMethod, 207 "getReferenceTimeMsUtc", "J"); 208 const std::pair<bool, jobject> status_or_reference_timezone = 209 CallJniMethod0<jobject>(env, jmessage, message_class.get(), 210 &JNIEnv::CallObjectMethod, "getReferenceTimezone", 211 "Ljava/lang/String;"); 212 const std::pair<bool, jobject> status_or_detected_text_language_tags = 213 CallJniMethod0<jobject>( 214 env, jmessage, message_class.get(), &JNIEnv::CallObjectMethod, 215 "getDetectedTextLanguageTags", "Ljava/lang/String;"); 216 if (!status_or_text.first || !status_or_user_id.first || 217 !status_or_detected_text_language_tags.first || 218 !status_or_reference_time.first || !status_or_reference_timezone.first) { 219 return {}; 220 } 221 222 ConversationMessage message; 223 message.text = ToStlString(env, static_cast<jstring>(status_or_text.second)); 224 message.user_id = status_or_user_id.second; 225 message.reference_time_ms_utc = status_or_reference_time.second; 226 message.reference_timezone = ToStlString( 227 env, static_cast<jstring>(status_or_reference_timezone.second)); 228 message.detected_text_language_tags = ToStlString( 229 env, static_cast<jstring>(status_or_detected_text_language_tags.second)); 230 return message; 231 } 232 233 Conversation FromJavaConversation(JNIEnv* env, jobject jconversation) { 234 if (!jconversation) { 235 return {}; 236 } 237 238 const ScopedLocalRef<jclass> conversation_class( 239 env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR 240 "$Conversation"), 241 env); 242 243 const std::pair<bool, jobject> status_or_messages = CallJniMethod0<jobject>( 244 env, jconversation, conversation_class.get(), &JNIEnv::CallObjectMethod, 245 "getConversationMessages", 246 "[L" TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR "$ConversationMessage;"); 247 248 if (!status_or_messages.first) { 249 return {}; 250 } 251 252 const jobjectArray jmessages = 253 reinterpret_cast<jobjectArray>(status_or_messages.second); 254 255 const int size = env->GetArrayLength(jmessages); 256 257 std::vector<ConversationMessage> messages; 258 for (int i = 0; i < size; i++) { 259 jobject jmessage = env->GetObjectArrayElement(jmessages, i); 260 ConversationMessage message = FromJavaConversationMessage(env, jmessage); 261 messages.push_back(message); 262 } 263 Conversation conversation; 264 conversation.messages = messages; 265 return conversation; 266 } 267 268 jstring GetLocalesFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) { 269 if (!mmap->handle().ok()) { 270 return env->NewStringUTF(""); 271 } 272 const ActionsModel* model = libtextclassifier3::ViewActionsModel( 273 mmap->handle().start(), mmap->handle().num_bytes()); 274 if (!model || !model->locales()) { 275 return env->NewStringUTF(""); 276 } 277 return env->NewStringUTF(model->locales()->c_str()); 278 } 279 280 jint GetVersionFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) { 281 if (!mmap->handle().ok()) { 282 return 0; 283 } 284 const ActionsModel* model = libtextclassifier3::ViewActionsModel( 285 mmap->handle().start(), mmap->handle().num_bytes()); 286 if (!model) { 287 return 0; 288 } 289 return model->version(); 290 } 291 292 jstring GetNameFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) { 293 if (!mmap->handle().ok()) { 294 return env->NewStringUTF(""); 295 } 296 const ActionsModel* model = libtextclassifier3::ViewActionsModel( 297 mmap->handle().start(), mmap->handle().num_bytes()); 298 if (!model || !model->name()) { 299 return env->NewStringUTF(""); 300 } 301 return env->NewStringUTF(model->name()->c_str()); 302 } 303 } // namespace 304 } // namespace libtextclassifier3 305 306 using libtextclassifier3::ActionsSuggestionsJniContext; 307 using libtextclassifier3::ActionSuggestionsToJObjectArray; 308 using libtextclassifier3::FromJavaActionSuggestionOptions; 309 using libtextclassifier3::FromJavaConversation; 310 311 TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModel) 312 (JNIEnv* env, jobject thiz, jint fd, jbyteArray serialized_preconditions) { 313 std::shared_ptr<libtextclassifier3::JniCache> jni_cache = 314 libtextclassifier3::JniCache::Create(env); 315 std::string preconditions; 316 if (serialized_preconditions != nullptr && 317 !libtextclassifier3::JByteArrayToString(env, serialized_preconditions, 318 &preconditions)) { 319 TC3_LOG(ERROR) << "Could not convert serialized preconditions."; 320 return 0; 321 } 322 #ifdef TC3_UNILIB_JAVAICU 323 return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create( 324 jni_cache, 325 ActionsSuggestions::FromFileDescriptor( 326 fd, std::unique_ptr<UniLib>(new UniLib(jni_cache)), preconditions))); 327 #else 328 return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create( 329 jni_cache, ActionsSuggestions::FromFileDescriptor(fd, /*unilib=*/nullptr, 330 preconditions))); 331 #endif // TC3_UNILIB_JAVAICU 332 } 333 334 TC3_JNI_METHOD(jlong, TC3_ACTIONS_CLASS_NAME, nativeNewActionsModelFromPath) 335 (JNIEnv* env, jobject thiz, jstring path, jbyteArray serialized_preconditions) { 336 std::shared_ptr<libtextclassifier3::JniCache> jni_cache = 337 libtextclassifier3::JniCache::Create(env); 338 const std::string path_str = ToStlString(env, path); 339 std::string preconditions; 340 if (serialized_preconditions != nullptr && 341 !libtextclassifier3::JByteArrayToString(env, serialized_preconditions, 342 &preconditions)) { 343 TC3_LOG(ERROR) << "Could not convert serialized preconditions."; 344 return 0; 345 } 346 #ifdef TC3_UNILIB_JAVAICU 347 return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create( 348 jni_cache, ActionsSuggestions::FromPath( 349 path_str, std::unique_ptr<UniLib>(new UniLib(jni_cache)), 350 preconditions))); 351 #else 352 return reinterpret_cast<jlong>(ActionsSuggestionsJniContext::Create( 353 jni_cache, ActionsSuggestions::FromPath(path_str, /*unilib=*/nullptr, 354 preconditions))); 355 #endif // TC3_UNILIB_JAVAICU 356 } 357 358 TC3_JNI_METHOD(jobjectArray, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions) 359 (JNIEnv* env, jobject clazz, jlong ptr, jobject jconversation, jobject joptions, 360 jlong annotatorPtr, jobject app_context, jstring device_locales, 361 jboolean generate_intents) { 362 if (!ptr) { 363 return nullptr; 364 } 365 const Conversation conversation = FromJavaConversation(env, jconversation); 366 const ActionSuggestionOptions options = 367 FromJavaActionSuggestionOptions(env, joptions); 368 const ActionsSuggestionsJniContext* context = 369 reinterpret_cast<ActionsSuggestionsJniContext*>(ptr); 370 const Annotator* annotator = reinterpret_cast<Annotator*>(annotatorPtr); 371 372 const ActionsSuggestionsResponse response = 373 context->model()->SuggestActions(conversation, annotator, options); 374 375 const reflection::Schema* anntotations_entity_data_schema = 376 annotator ? annotator->entity_data_schema() : nullptr; 377 return ActionSuggestionsToJObjectArray( 378 env, context, app_context, anntotations_entity_data_schema, 379 response.actions, conversation, device_locales, generate_intents); 380 } 381 382 TC3_JNI_METHOD(void, TC3_ACTIONS_CLASS_NAME, nativeCloseActionsModel) 383 (JNIEnv* env, jobject clazz, jlong model_ptr) { 384 const ActionsSuggestionsJniContext* context = 385 reinterpret_cast<ActionsSuggestionsJniContext*>(model_ptr); 386 delete context; 387 } 388 389 TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetLocales) 390 (JNIEnv* env, jobject clazz, jint fd) { 391 const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap( 392 new libtextclassifier3::ScopedMmap(fd)); 393 return libtextclassifier3::GetLocalesFromMmap(env, mmap.get()); 394 } 395 396 TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetName) 397 (JNIEnv* env, jobject clazz, jint fd) { 398 const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap( 399 new libtextclassifier3::ScopedMmap(fd)); 400 return libtextclassifier3::GetNameFromMmap(env, mmap.get()); 401 } 402 403 TC3_JNI_METHOD(jint, TC3_ACTIONS_CLASS_NAME, nativeGetVersion) 404 (JNIEnv* env, jobject clazz, jint fd) { 405 const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap( 406 new libtextclassifier3::ScopedMmap(fd)); 407 return libtextclassifier3::GetVersionFromMmap(env, mmap.get()); 408 } 409