Home | History | Annotate | Download | only in intents
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "utils/intents/intent-generator.h"
     18 
     19 #include <vector>
     20 
     21 #include "actions/lua-utils.h"
     22 #include "actions/types.h"
     23 #include "annotator/types.h"
     24 #include "utils/base/logging.h"
     25 #include "utils/hash/farmhash.h"
     26 #include "utils/java/jni-base.h"
     27 #include "utils/java/string_utils.h"
     28 #include "utils/lua-utils.h"
     29 #include "utils/strings/stringpiece.h"
     30 #include "utils/strings/substitute.h"
     31 #include "utils/utf8/unicodetext.h"
     32 #include "utils/variant.h"
     33 #include "utils/zlib/zlib.h"
     34 #include "flatbuffers/reflection_generated.h"
     35 
     36 #ifdef __cplusplus
     37 extern "C" {
     38 #endif
     39 #include "lauxlib.h"
     40 #include "lua.h"
     41 #ifdef __cplusplus
     42 }
     43 #endif
     44 
     45 namespace libtextclassifier3 {
     46 namespace {
     47 
     48 static constexpr const char* kReferenceTimeUsecKey = "reference_time_ms_utc";
     49 static constexpr const char* kHashKey = "hash";
     50 static constexpr const char* kUrlSchemaKey = "url_schema";
     51 static constexpr const char* kUrlHostKey = "url_host";
     52 static constexpr const char* kUrlEncodeKey = "urlencode";
     53 static constexpr const char* kPackageNameKey = "package_name";
     54 static constexpr const char* kDeviceLocaleKey = "device_locales";
     55 static constexpr const char* kFormatKey = "format";
     56 
// An Android specific Lua environment with JNI backed callbacks.
//
// Generator snippets see a global `external` table (installed by
// Initialize()). Its `android` sub-table forwards lookups to Java through the
// JniCache. These lookups cover the device locales, the package name, the URL
// helpers, user restrictions and string resources.
//
// NOTE(review): `resources`, `jni_cache` and `context` are kept by reference,
// by raw pointer and as a raw jobject, without taking ownership. Presumably
// the caller keeps them alive for the lifetime of the environment; confirm.
class JniLuaEnvironment : public LuaEnvironment {
 public:
  JniLuaEnvironment(const Resources& resources, const JniCache* jni_cache,
                    const jobject context,
                    const std::vector<Locale>& device_locales);
  // Environment setup: allocates the cached Java strings and installs the
  // `external` global. Returns false on failure.
  bool Initialize();

  // Runs an intent generator snippet and appends the remote action templates
  // it returns to `remote_actions`. Returns false on failure.
  bool RunIntentGenerator(const std::string& generator_snippet,
                          std::vector<RemoteActionTemplate>* remote_actions);

 protected:
  // Pushes the `external` table onto the Lua stack (Initialize() then stores
  // it as a global). Subclasses override this to add further fields.
  virtual void SetupExternalHook();

  // Lua callbacks. Each returns the number of values it pushed onto the Lua
  // stack, or raises a Lua error through lua_error().
  int HandleExternalCallback();
  int HandleAndroidCallback();
  int HandleUserRestrictionsCallback();
  int HandleUrlEncode();
  int HandleUrlSchema();
  int HandleHash();
  int HandleFormat();
  int HandleAndroidStringResources();
  int HandleUrlHost();

  // Checks and retrieves string resources from the model.
  bool LookupModelStringResource();

  // Reads and create a RemoteAction result from Lua.
  RemoteActionTemplate ReadRemoteActionTemplateResult();

  // Reads the extras from the Lua result.
  void ReadExtras(std::map<std::string, Variant>* extra);

  // Reads the intent categories array from a Lua result.
  void ReadCategories(std::vector<std::string>* category);

  // Retrieves user manager if not previously done.
  bool RetrieveUserManager();

  // Retrieves system resources if not previously done.
  bool RetrieveSystemResources();

  // Parse the url string by using Uri.parse from Java.
  ScopedLocalRef<jobject> ParseUri(StringPiece url) const;

  // Read remote action templates from lua generator.
  int ReadRemoteActionTemplates(std::vector<RemoteActionTemplate>* result);

  // Model-provided resources, consulted before the Android system resources.
  const Resources& resources_;
  // JNI environment obtained from `jni_cache_`; null when no cache was given.
  JNIEnv* jenv_;
  const JniCache* jni_cache_;
  // Android context used for getPackageName and getSystemService. It may be
  // null; the package-name and user-manager lookups check for that.
  const jobject context_;
  std::vector<Locale> device_locales_;

  ScopedGlobalRef<jobject> usermanager_;
  // Whether we previously attempted to retrieve the UserManager before.
  bool usermanager_retrieved_;

  ScopedGlobalRef<jobject> system_resources_;
  // Whether we previously attempted to retrieve the system resources.
  bool system_resources_resources_retrieved_;

  // Cached JNI references for Java strings `string` and `android`.
  ScopedGlobalRef<jstring> string_;
  ScopedGlobalRef<jstring> android_;
};
    125 
// Stores the collaborators without taking ownership. `jni_cache` may be null.
// In that case the JNI environment and the JavaVM handed to every
// global-reference slot are null too. The lazily fetched objects (user
// manager, system resources) start out unretrieved. The cached Java strings
// are allocated later by Initialize().
JniLuaEnvironment::JniLuaEnvironment(const Resources& resources,
                                     const JniCache* jni_cache,
                                     const jobject context,
                                     const std::vector<Locale>& device_locales)
    : resources_(resources),
      jenv_(jni_cache ? jni_cache->GetEnv() : nullptr),
      jni_cache_(jni_cache),
      context_(context),
      device_locales_(device_locales),
      usermanager_(/*object=*/nullptr,
                   /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
      usermanager_retrieved_(false),
      system_resources_(/*object=*/nullptr,
                        /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
      system_resources_resources_retrieved_(false),
      string_(/*object=*/nullptr,
              /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
      android_(/*object=*/nullptr,
               /*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)) {}
    145 
    146 bool JniLuaEnvironment::Initialize() {
    147   string_ =
    148       MakeGlobalRef(jenv_->NewStringUTF("string"), jenv_, jni_cache_->jvm);
    149   android_ =
    150       MakeGlobalRef(jenv_->NewStringUTF("android"), jenv_, jni_cache_->jvm);
    151   if (string_ == nullptr || android_ == nullptr) {
    152     TC3_LOG(ERROR) << "Could not allocate constant strings references.";
    153     return false;
    154   }
    155   return (RunProtected([this] {
    156             LoadDefaultLibraries();
    157             SetupExternalHook();
    158             lua_setglobal(state_, "external");
    159             return LUA_OK;
    160           }) == LUA_OK);
    161 }
    162 
// Pushes the `external` table onto the Lua stack and leaves it on top.
// Initialize() stores it as a global, and subclass overrides add fields to it
// at index -2.
void JniLuaEnvironment::SetupExternalHook() {
  // This exposes an `external` object with the following fields:
  //   * entity: the bundle with all information about a classification
  //     (added by the subclass overrides, not here).
  //   * android: callbacks into specific android provided methods.
  //   * android.user_restrictions: callbacks to check user permissions.
  //   * android.R: callbacks to retrieve string resources.
  // Each BindTable call pushes one table whose lookups are routed to the given
  // handler. The lua_setfield calls below then pop it into its parent.
  BindTable<JniLuaEnvironment, &JniLuaEnvironment::HandleExternalCallback>(
      "external");

  // android
  BindTable<JniLuaEnvironment, &JniLuaEnvironment::HandleAndroidCallback>(
      "android");
  {
    // android.user_restrictions
    BindTable<JniLuaEnvironment,
              &JniLuaEnvironment::HandleUserRestrictionsCallback>(
        "user_restrictions");
    lua_setfield(state_, /*idx=*/-2, "user_restrictions");

    // android.R
    // Callback to access android string resources.
    BindTable<JniLuaEnvironment,
              &JniLuaEnvironment::HandleAndroidStringResources>("R");
    lua_setfield(state_, /*idx=*/-2, "R");
  }
  // Stack is now [external, android]; attach android to external.
  lua_setfield(state_, /*idx=*/-2, "android");
}
    190 
    191 int JniLuaEnvironment::HandleExternalCallback() {
    192   const StringPiece key = ReadString(/*index=*/-1);
    193   if (key.Equals(kHashKey)) {
    194     Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleHash>();
    195     return 1;
    196   } else if (key.Equals(kFormatKey)) {
    197     Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleFormat>();
    198     return 1;
    199   } else {
    200     TC3_LOG(ERROR) << "Undefined external access " << key.ToString();
    201     lua_error(state_);
    202     return 0;
    203   }
    204 }
    205 
    206 int JniLuaEnvironment::HandleAndroidCallback() {
    207   const StringPiece key = ReadString(/*index=*/-1);
    208   if (key.Equals(kDeviceLocaleKey)) {
    209     // Provide the locale as table with the individual fields set.
    210     lua_newtable(state_);
    211     for (int i = 0; i < device_locales_.size(); i++) {
    212       // Adjust index to 1-based indexing for Lua.
    213       lua_pushinteger(state_, i + 1);
    214       lua_newtable(state_);
    215       PushString(device_locales_[i].Language());
    216       lua_setfield(state_, -2, "language");
    217       PushString(device_locales_[i].Region());
    218       lua_setfield(state_, -2, "region");
    219       PushString(device_locales_[i].Script());
    220       lua_setfield(state_, -2, "script");
    221       lua_settable(state_, /*idx=*/-3);
    222     }
    223     return 1;
    224   } else if (key.Equals(kPackageNameKey)) {
    225     if (context_ == nullptr) {
    226       TC3_LOG(ERROR) << "Context invalid.";
    227       lua_error(state_);
    228       return 0;
    229     }
    230     ScopedLocalRef<jstring> package_name_str(
    231         static_cast<jstring>(jenv_->CallObjectMethod(
    232             context_, jni_cache_->context_get_package_name)));
    233     if (jni_cache_->ExceptionCheckAndClear()) {
    234       TC3_LOG(ERROR) << "Error calling Context.getPackageName";
    235       lua_error(state_);
    236       return 0;
    237     }
    238     PushString(ToStlString(jenv_, package_name_str.get()));
    239     return 1;
    240   } else if (key.Equals(kUrlEncodeKey)) {
    241     Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleUrlEncode>();
    242     return 1;
    243   } else if (key.Equals(kUrlHostKey)) {
    244     Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleUrlHost>();
    245     return 1;
    246   } else if (key.Equals(kUrlSchemaKey)) {
    247     Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleUrlSchema>();
    248     return 1;
    249   } else {
    250     TC3_LOG(ERROR) << "Undefined android reference " << key.ToString();
    251     lua_error(state_);
    252     return 0;
    253   }
    254 }
    255 
    256 int JniLuaEnvironment::HandleUserRestrictionsCallback() {
    257   if (jni_cache_->usermanager_class == nullptr ||
    258       jni_cache_->usermanager_get_user_restrictions == nullptr) {
    259     // UserManager is only available for API level >= 17 and
    260     // getUserRestrictions only for API level >= 18, so we just return false
    261     // normally here.
    262     lua_pushboolean(state_, false);
    263     return 1;
    264   }
    265 
    266   // Get user manager if not previously retrieved.
    267   if (!RetrieveUserManager()) {
    268     TC3_LOG(ERROR) << "Error retrieving user manager.";
    269     lua_error(state_);
    270     return 0;
    271   }
    272 
    273   ScopedLocalRef<jobject> bundle(jenv_->CallObjectMethod(
    274       usermanager_.get(), jni_cache_->usermanager_get_user_restrictions));
    275   if (jni_cache_->ExceptionCheckAndClear() || bundle == nullptr) {
    276     TC3_LOG(ERROR) << "Error calling getUserRestrictions";
    277     lua_error(state_);
    278     return 0;
    279   }
    280 
    281   const StringPiece key_str = ReadString(/*index=*/-1);
    282   if (key_str.empty()) {
    283     TC3_LOG(ERROR) << "Expected string, got null.";
    284     lua_error(state_);
    285     return 0;
    286   }
    287 
    288   ScopedLocalRef<jstring> key = jni_cache_->ConvertToJavaString(key_str);
    289   if (jni_cache_->ExceptionCheckAndClear() || key == nullptr) {
    290     TC3_LOG(ERROR) << "Expected string, got null.";
    291     lua_error(state_);
    292     return 0;
    293   }
    294   const bool permission = jenv_->CallBooleanMethod(
    295       bundle.get(), jni_cache_->bundle_get_boolean, key.get());
    296   if (jni_cache_->ExceptionCheckAndClear()) {
    297     TC3_LOG(ERROR) << "Error getting bundle value";
    298     lua_pushboolean(state_, false);
    299   } else {
    300     lua_pushboolean(state_, permission);
    301   }
    302   return 1;
    303 }
    304 
    305 int JniLuaEnvironment::HandleUrlEncode() {
    306   const StringPiece input = ReadString(/*index=*/1);
    307   if (input.empty()) {
    308     TC3_LOG(ERROR) << "Expected string, got null.";
    309     lua_error(state_);
    310     return 0;
    311   }
    312 
    313   // Call Java URL encoder.
    314   ScopedLocalRef<jstring> input_str = jni_cache_->ConvertToJavaString(input);
    315   if (jni_cache_->ExceptionCheckAndClear() || input_str == nullptr) {
    316     TC3_LOG(ERROR) << "Expected string, got null.";
    317     lua_error(state_);
    318     return 0;
    319   }
    320   ScopedLocalRef<jstring> encoded_str(
    321       static_cast<jstring>(jenv_->CallStaticObjectMethod(
    322           jni_cache_->urlencoder_class.get(), jni_cache_->urlencoder_encode,
    323           input_str.get(), jni_cache_->string_utf8.get())));
    324   if (jni_cache_->ExceptionCheckAndClear()) {
    325     TC3_LOG(ERROR) << "Error calling UrlEncoder.encode";
    326     lua_error(state_);
    327     return 0;
    328   }
    329   PushString(ToStlString(jenv_, encoded_str.get()));
    330   return 1;
    331 }
    332 
    333 ScopedLocalRef<jobject> JniLuaEnvironment::ParseUri(StringPiece url) const {
    334   if (url.empty()) {
    335     return nullptr;
    336   }
    337 
    338   // Call to Java URI parser.
    339   ScopedLocalRef<jstring> url_str = jni_cache_->ConvertToJavaString(url);
    340   if (jni_cache_->ExceptionCheckAndClear() || url_str == nullptr) {
    341     TC3_LOG(ERROR) << "Expected string, got null";
    342     return nullptr;
    343   }
    344 
    345   // Try to parse uri and get scheme.
    346   ScopedLocalRef<jobject> uri(jenv_->CallStaticObjectMethod(
    347       jni_cache_->uri_class.get(), jni_cache_->uri_parse, url_str.get()));
    348   if (jni_cache_->ExceptionCheckAndClear() || uri == nullptr) {
    349     TC3_LOG(ERROR) << "Error calling Uri.parse";
    350   }
    351   return uri;
    352 }
    353 
    354 int JniLuaEnvironment::HandleUrlSchema() {
    355   StringPiece url = ReadString(/*index=*/1);
    356 
    357   ScopedLocalRef<jobject> parsed_uri = ParseUri(url);
    358   if (parsed_uri == nullptr) {
    359     lua_error(state_);
    360     return 0;
    361   }
    362 
    363   ScopedLocalRef<jstring> scheme_str(static_cast<jstring>(
    364       jenv_->CallObjectMethod(parsed_uri.get(), jni_cache_->uri_get_scheme)));
    365   if (jni_cache_->ExceptionCheckAndClear()) {
    366     TC3_LOG(ERROR) << "Error calling Uri.getScheme";
    367     lua_error(state_);
    368     return 0;
    369   }
    370   if (scheme_str == nullptr) {
    371     lua_pushnil(state_);
    372   } else {
    373     PushString(ToStlString(jenv_, scheme_str.get()));
    374   }
    375   return 1;
    376 }
    377 
    378 int JniLuaEnvironment::HandleUrlHost() {
    379   StringPiece url = ReadString(/*index=*/-1);
    380 
    381   ScopedLocalRef<jobject> parsed_uri = ParseUri(url);
    382   if (parsed_uri == nullptr) {
    383     lua_error(state_);
    384     return 0;
    385   }
    386 
    387   ScopedLocalRef<jstring> host_str(static_cast<jstring>(
    388       jenv_->CallObjectMethod(parsed_uri.get(), jni_cache_->uri_get_host)));
    389   if (jni_cache_->ExceptionCheckAndClear()) {
    390     TC3_LOG(ERROR) << "Error calling Uri.getHost";
    391     lua_error(state_);
    392     return 0;
    393   }
    394   if (host_str == nullptr) {
    395     lua_pushnil(state_);
    396   } else {
    397     PushString(ToStlString(jenv_, host_str.get()));
    398   }
    399   return 1;
    400 }
    401 
    402 int JniLuaEnvironment::HandleHash() {
    403   const StringPiece input = ReadString(/*index=*/-1);
    404   lua_pushinteger(state_, tc3farmhash::Hash32(input.data(), input.length()));
    405   return 1;
    406 }
    407 
    408 int JniLuaEnvironment::HandleFormat() {
    409   const int num_args = lua_gettop(state_);
    410   std::vector<StringPiece> args(num_args - 1);
    411   for (int i = 0; i < num_args - 1; i++) {
    412     args[i] = ReadString(/*index=*/i + 2);
    413   }
    414   PushString(strings::Substitute(ReadString(/*index=*/1), args));
    415   return 1;
    416 }
    417 
    418 bool JniLuaEnvironment::LookupModelStringResource() {
    419   // Handle only lookup by name.
    420   if (lua_type(state_, 2) != LUA_TSTRING) {
    421     return false;
    422   }
    423 
    424   const StringPiece resource_name = ReadString(/*index=*/-1);
    425   std::string resource_content;
    426   if (!resources_.GetResourceContent(device_locales_, resource_name,
    427                                      &resource_content)) {
    428     // Resource cannot be provided by the model.
    429     return false;
    430   }
    431 
    432   PushString(resource_content);
    433   return true;
    434 }
    435 
    436 int JniLuaEnvironment::HandleAndroidStringResources() {
    437   // Check whether the requested resource can be served from the model data.
    438   if (LookupModelStringResource()) {
    439     return 1;
    440   }
    441 
    442   // Get system resources if not previously retrieved.
    443   if (!RetrieveSystemResources()) {
    444     TC3_LOG(ERROR) << "Error retrieving system resources.";
    445     lua_error(state_);
    446     return 0;
    447   }
    448 
    449   int resource_id;
    450   switch (lua_type(state_, -1)) {
    451     case LUA_TNUMBER:
    452       resource_id = static_cast<int>(lua_tonumber(state_, /*idx=*/-1));
    453       break;
    454     case LUA_TSTRING: {
    455       const StringPiece resource_name_str = ReadString(/*index=*/-1);
    456       if (resource_name_str.empty()) {
    457         TC3_LOG(ERROR) << "No resource name provided.";
    458         lua_error(state_);
    459         return 0;
    460       }
    461       ScopedLocalRef<jstring> resource_name =
    462           jni_cache_->ConvertToJavaString(resource_name_str);
    463       if (resource_name == nullptr) {
    464         TC3_LOG(ERROR) << "Invalid resource name.";
    465         lua_error(state_);
    466         return 0;
    467       }
    468       resource_id = jenv_->CallIntMethod(
    469           system_resources_.get(), jni_cache_->resources_get_identifier,
    470           resource_name.get(), string_.get(), android_.get());
    471       if (jni_cache_->ExceptionCheckAndClear()) {
    472         TC3_LOG(ERROR) << "Error calling getIdentifier.";
    473         lua_error(state_);
    474         return 0;
    475       }
    476       break;
    477     }
    478     default:
    479       TC3_LOG(ERROR) << "Unexpected type for resource lookup.";
    480       lua_error(state_);
    481       return 0;
    482   }
    483   if (resource_id == 0) {
    484     TC3_LOG(ERROR) << "Resource not found.";
    485     lua_pushnil(state_);
    486     return 1;
    487   }
    488   ScopedLocalRef<jstring> resource_str(static_cast<jstring>(
    489       jenv_->CallObjectMethod(system_resources_.get(),
    490                               jni_cache_->resources_get_string, resource_id)));
    491   if (jni_cache_->ExceptionCheckAndClear()) {
    492     TC3_LOG(ERROR) << "Error calling getString.";
    493     lua_error(state_);
    494     return 0;
    495   }
    496   if (resource_str == nullptr) {
    497     lua_pushnil(state_);
    498   } else {
    499     PushString(ToStlString(jenv_, resource_str.get()));
    500   }
    501   return 1;
    502 }
    503 
    504 bool JniLuaEnvironment::RetrieveSystemResources() {
    505   if (system_resources_resources_retrieved_) {
    506     return (system_resources_ != nullptr);
    507   }
    508   system_resources_resources_retrieved_ = true;
    509   jobject system_resources_ref = jenv_->CallStaticObjectMethod(
    510       jni_cache_->resources_class.get(), jni_cache_->resources_get_system);
    511   if (jni_cache_->ExceptionCheckAndClear()) {
    512     TC3_LOG(ERROR) << "Error calling getSystem.";
    513     return false;
    514   }
    515   system_resources_ =
    516       MakeGlobalRef(system_resources_ref, jenv_, jni_cache_->jvm);
    517   return (system_resources_ != nullptr);
    518 }
    519 
    520 bool JniLuaEnvironment::RetrieveUserManager() {
    521   if (context_ == nullptr) {
    522     return false;
    523   }
    524   if (usermanager_retrieved_) {
    525     return (usermanager_ != nullptr);
    526   }
    527   usermanager_retrieved_ = true;
    528   ScopedLocalRef<jstring> service(jenv_->NewStringUTF("user"));
    529   jobject usermanager_ref = jenv_->CallObjectMethod(
    530       context_, jni_cache_->context_get_system_service, service.get());
    531   if (jni_cache_->ExceptionCheckAndClear()) {
    532     TC3_LOG(ERROR) << "Error calling getSystemService.";
    533     return false;
    534   }
    535   usermanager_ = MakeGlobalRef(usermanager_ref, jenv_, jni_cache_->jvm);
    536   return (usermanager_ != nullptr);
    537 }
    538 
    539 RemoteActionTemplate JniLuaEnvironment::ReadRemoteActionTemplateResult() {
    540   RemoteActionTemplate result;
    541   // Read intent template.
    542   lua_pushnil(state_);
    543   while (lua_next(state_, /*idx=*/-2)) {
    544     const StringPiece key = ReadString(/*index=*/-2);
    545     if (key.Equals("title_without_entity")) {
    546       result.title_without_entity = ReadString(/*index=*/-1).ToString();
    547     } else if (key.Equals("title_with_entity")) {
    548       result.title_with_entity = ReadString(/*index=*/-1).ToString();
    549     } else if (key.Equals("description")) {
    550       result.description = ReadString(/*index=*/-1).ToString();
    551     } else if (key.Equals("description_with_app_name")) {
    552       result.description_with_app_name = ReadString(/*index=*/-1).ToString();
    553     } else if (key.Equals("action")) {
    554       result.action = ReadString(/*index=*/-1).ToString();
    555     } else if (key.Equals("data")) {
    556       result.data = ReadString(/*index=*/-1).ToString();
    557     } else if (key.Equals("type")) {
    558       result.type = ReadString(/*index=*/-1).ToString();
    559     } else if (key.Equals("flags")) {
    560       result.flags = static_cast<int>(lua_tointeger(state_, /*idx=*/-1));
    561     } else if (key.Equals("package_name")) {
    562       result.package_name = ReadString(/*index=*/-1).ToString();
    563     } else if (key.Equals("request_code")) {
    564       result.request_code = static_cast<int>(lua_tointeger(state_, /*idx=*/-1));
    565     } else if (key.Equals("category")) {
    566       ReadCategories(&result.category);
    567     } else if (key.Equals("extra")) {
    568       ReadExtras(&result.extra);
    569     } else {
    570       TC3_LOG(INFO) << "Unknown entry: " << key.ToString();
    571     }
    572     lua_pop(state_, 1);
    573   }
    574   lua_pop(state_, 1);
    575   return result;
    576 }
    577 
    578 void JniLuaEnvironment::ReadCategories(std::vector<std::string>* category) {
    579   // Read category array.
    580   if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
    581     TC3_LOG(ERROR) << "Expected categories table, got: "
    582                    << lua_type(state_, /*idx=*/-1);
    583     lua_pop(state_, 1);
    584     return;
    585   }
    586   lua_pushnil(state_);
    587   while (lua_next(state_, /*idx=*/-2)) {
    588     category->push_back(ReadString(/*index=*/-1).ToString());
    589     lua_pop(state_, 1);
    590   }
    591 }
    592 
    593 void JniLuaEnvironment::ReadExtras(std::map<std::string, Variant>* extra) {
    594   if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
    595     TC3_LOG(ERROR) << "Expected extras table, got: "
    596                    << lua_type(state_, /*idx=*/-1);
    597     lua_pop(state_, 1);
    598     return;
    599   }
    600   lua_pushnil(state_);
    601   while (lua_next(state_, /*idx=*/-2)) {
    602     // Each entry is a table specifying name and value.
    603     // The value is specified via a type specific field as Lua doesn't allow
    604     // to easily distinguish between different number types.
    605     if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
    606       TC3_LOG(ERROR) << "Expected a table for an extra, got: "
    607                      << lua_type(state_, /*idx=*/-1);
    608       lua_pop(state_, 1);
    609       return;
    610     }
    611     std::string name;
    612     Variant value;
    613 
    614     lua_pushnil(state_);
    615     while (lua_next(state_, /*idx=*/-2)) {
    616       const StringPiece key = ReadString(/*index=*/-2);
    617       if (key.Equals("name")) {
    618         name = ReadString(/*index=*/-1).ToString();
    619       } else if (key.Equals("int_value")) {
    620         value = Variant(static_cast<int>(lua_tonumber(state_, /*idx=*/-1)));
    621       } else if (key.Equals("long_value")) {
    622         value = Variant(static_cast<int64>(lua_tonumber(state_, /*idx=*/-1)));
    623       } else if (key.Equals("float_value")) {
    624         value = Variant(static_cast<float>(lua_tonumber(state_, /*idx=*/-1)));
    625       } else if (key.Equals("bool_value")) {
    626         value = Variant(static_cast<bool>(lua_toboolean(state_, /*idx=*/-1)));
    627       } else if (key.Equals("string_value")) {
    628         value = Variant(ReadString(/*index=*/-1).ToString());
    629       } else {
    630         TC3_LOG(INFO) << "Unknown extra field: " << key.ToString();
    631       }
    632       lua_pop(state_, 1);
    633     }
    634     if (!name.empty()) {
    635       (*extra)[name] = value;
    636     } else {
    637       TC3_LOG(ERROR) << "Unnamed extra entry. Skipping.";
    638     }
    639     lua_pop(state_, 1);
    640   }
    641 }
    642 
    643 int JniLuaEnvironment::ReadRemoteActionTemplates(
    644     std::vector<RemoteActionTemplate>* result) {
    645   // Read result.
    646   if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
    647     TC3_LOG(ERROR) << "Unexpected result for snippet: " << lua_type(state_, -1);
    648     lua_error(state_);
    649     return LUA_ERRRUN;
    650   }
    651 
    652   // Read remote action templates array.
    653   lua_pushnil(state_);
    654   while (lua_next(state_, /*idx=*/-2)) {
    655     if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
    656       TC3_LOG(ERROR) << "Expected intent table, got: "
    657                      << lua_type(state_, /*idx=*/-1);
    658       lua_pop(state_, 1);
    659       continue;
    660     }
    661     result->push_back(ReadRemoteActionTemplateResult());
    662   }
    663   lua_pop(state_, /*n=*/1);
    664   return LUA_OK;
    665 }
    666 
    667 bool JniLuaEnvironment::RunIntentGenerator(
    668     const std::string& generator_snippet,
    669     std::vector<RemoteActionTemplate>* remote_actions) {
    670   int status;
    671   status = luaL_loadbuffer(state_, generator_snippet.data(),
    672                            generator_snippet.size(),
    673                            /*name=*/nullptr);
    674   if (status != LUA_OK) {
    675     TC3_LOG(ERROR) << "Couldn't load generator snippet: " << status;
    676     return false;
    677   }
    678   status = lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0);
    679   if (status != LUA_OK) {
    680     TC3_LOG(ERROR) << "Couldn't run generator snippet: " << status;
    681     return false;
    682   }
    683   if (RunProtected(
    684           [this, remote_actions] {
    685             return ReadRemoteActionTemplates(remote_actions);
    686           },
    687           /*num_args=*/1) != LUA_OK) {
    688     TC3_LOG(ERROR) << "Could not read results.";
    689     return false;
    690   }
    691   // Check that we correctly cleaned-up the state.
    692   const int stack_size = lua_gettop(state_);
    693   if (stack_size > 0) {
    694     TC3_LOG(ERROR) << "Unexpected stack size.";
    695     lua_settop(state_, 0);
    696     return false;
    697   }
    698   return true;
    699 }
    700 
// Lua environment for classification result intent generation.
//
// On top of the base `external` table, this exposes
// `external.reference_time_ms_utc` and `external.entity`. The entity is built
// by PushAnnotation from the classification, the entity text and the entity
// data schema.
//
// NOTE(review): `entity_text` and `classification` are stored by reference,
// so they must outlive this environment.
class AnnotatorJniEnvironment : public JniLuaEnvironment {
 public:
  AnnotatorJniEnvironment(const Resources& resources, const JniCache* jni_cache,
                          const jobject context,
                          const std::vector<Locale>& device_locales,
                          const std::string& entity_text,
                          const ClassificationResult& classification,
                          const int64 reference_time_ms_utc,
                          const reflection::Schema* entity_data_schema)
      : JniLuaEnvironment(resources, jni_cache, context, device_locales),
        entity_text_(entity_text),
        classification_(classification),
        reference_time_ms_utc_(reference_time_ms_utc),
        entity_data_schema_(entity_data_schema) {}

 protected:
  // Builds the base `external` table (left on top of the stack), then adds
  // the reference time and the entity to it.
  void SetupExternalHook() override {
    JniLuaEnvironment::SetupExternalHook();
    lua_pushinteger(state_, reference_time_ms_utc_);
    // kReferenceTimeUsecKey names the "reference_time_ms_utc" field, so the
    // value stored here is in milliseconds despite the identifier.
    lua_setfield(state_, /*idx=*/-2, kReferenceTimeUsecKey);

    PushAnnotation(classification_, entity_text_, entity_data_schema_, this);
    lua_setfield(state_, /*idx=*/-2, "entity");
  }

  const std::string& entity_text_;
  const ClassificationResult& classification_;
  const int64 reference_time_ms_utc_;

  // Reflection schema data.
  const reflection::Schema* const entity_data_schema_;
};
    734 
// Lua environment for actions intent generation.
//
// On top of the base `external` table, this exposes:
//   * `external.conversation`: an iterator over the conversation messages;
//   * `external.entity`: the action suggestion, pushed by PushAction.
//
// NOTE(review): `conversation` and `action` are stored by reference, so they
// must outlive this environment. Both iterators are constructed with the
// annotations schema. Presumably this is because messages carry annotations;
// confirm.
class ActionsJniLuaEnvironment : public JniLuaEnvironment {
 public:
  ActionsJniLuaEnvironment(
      const Resources& resources, const JniCache* jni_cache,
      const jobject context, const std::vector<Locale>& device_locales,
      const Conversation& conversation, const ActionSuggestion& action,
      const reflection::Schema* actions_entity_data_schema,
      const reflection::Schema* annotations_entity_data_schema)
      : JniLuaEnvironment(resources, jni_cache, context, device_locales),
        conversation_(conversation),
        action_(action),
        annotation_iterator_(annotations_entity_data_schema, this),
        conversation_iterator_(annotations_entity_data_schema, this),
        entity_data_schema_(actions_entity_data_schema) {}

 protected:
  // Builds the base `external` table (left on top of the stack), then adds
  // the conversation iterator and the action entity to it.
  void SetupExternalHook() override {
    JniLuaEnvironment::SetupExternalHook();
    conversation_iterator_.NewIterator("conversation", &conversation_.messages,
                                       state_);
    lua_setfield(state_, /*idx=*/-2, "conversation");

    PushAction(action_, entity_data_schema_, annotation_iterator_, this);
    lua_setfield(state_, /*idx=*/-2, "entity");
  }

  const Conversation& conversation_;
  const ActionSuggestion& action_;
  const AnnotationIterator<ActionSuggestionAnnotation> annotation_iterator_;
  const ConversationIterator conversation_iterator_;
  // Entity data schema for the actions model.
  const reflection::Schema* entity_data_schema_;
};
    768 
    769 }  // namespace
    770 
    771 std::unique_ptr<IntentGenerator> IntentGenerator::Create(
    772     const IntentFactoryModel* options, const ResourcePool* resources,
    773     const std::shared_ptr<JniCache>& jni_cache) {
    774   std::unique_ptr<IntentGenerator> intent_generator(
    775       new IntentGenerator(options, resources, jni_cache));
    776 
    777   if (options == nullptr || options->generator() == nullptr) {
    778     TC3_LOG(ERROR) << "No intent generator options.";
    779     return nullptr;
    780   }
    781 
    782   std::unique_ptr<ZlibDecompressor> zlib_decompressor =
    783       ZlibDecompressor::Instance();
    784   if (!zlib_decompressor) {
    785     TC3_LOG(ERROR) << "Cannot initialize decompressor.";
    786     return nullptr;
    787   }
    788 
    789   for (const IntentFactoryModel_::IntentGenerator* generator :
    790        *options->generator()) {
    791     std::string lua_template_generator;
    792     if (!zlib_decompressor->MaybeDecompressOptionallyCompressedBuffer(
    793             generator->lua_template_generator(),
    794             generator->compressed_lua_template_generator(),
    795             &lua_template_generator)) {
    796       TC3_LOG(ERROR) << "Could not decompress generator template.";
    797       return nullptr;
    798     }
    799 
    800     std::string lua_code = lua_template_generator;
    801     if (options->precompile_generators()) {
    802       if (!Compile(lua_template_generator, &lua_code)) {
    803         TC3_LOG(ERROR) << "Could not precompile generator template.";
    804         return nullptr;
    805       }
    806     }
    807 
    808     intent_generator->generators_[generator->type()->str()] = lua_code;
    809   }
    810 
    811   return intent_generator;
    812 }
    813 
    814 std::vector<Locale> IntentGenerator::ParseDeviceLocales(
    815     const jstring device_locales) const {
    816   if (device_locales == nullptr) {
    817     TC3_LOG(ERROR) << "No locales provided.";
    818     return {};
    819   }
    820   ScopedStringChars locales_str =
    821       GetScopedStringChars(jni_cache_->GetEnv(), device_locales);
    822   if (locales_str == nullptr) {
    823     TC3_LOG(ERROR) << "Cannot retrieve provided locales.";
    824     return {};
    825   }
    826   std::vector<Locale> locales;
    827   if (!ParseLocales(reinterpret_cast<const char*>(locales_str.get()),
    828                     &locales)) {
    829     TC3_LOG(ERROR) << "Cannot parse locales.";
    830     return {};
    831   }
    832   return locales;
    833 }
    834 
    835 bool IntentGenerator::GenerateIntents(
    836     const jstring device_locales, const ClassificationResult& classification,
    837     const int64 reference_time_ms_utc, const std::string& text,
    838     const CodepointSpan selection_indices, const jobject context,
    839     const reflection::Schema* annotations_entity_data_schema,
    840     std::vector<RemoteActionTemplate>* remote_actions) const {
    841   if (options_ == nullptr) {
    842     return false;
    843   }
    844 
    845   // Retrieve generator for specified entity.
    846   auto it = generators_.find(classification.collection);
    847   if (it == generators_.end()) {
    848     return true;
    849   }
    850 
    851   const std::string entity_text =
    852       UTF8ToUnicodeText(text, /*do_copy=*/false)
    853           .UTF8Substring(selection_indices.first, selection_indices.second);
    854 
    855   std::unique_ptr<AnnotatorJniEnvironment> interpreter(
    856       new AnnotatorJniEnvironment(
    857           resources_, jni_cache_.get(), context,
    858           ParseDeviceLocales(device_locales), entity_text, classification,
    859           reference_time_ms_utc, annotations_entity_data_schema));
    860 
    861   if (!interpreter->Initialize()) {
    862     TC3_LOG(ERROR) << "Could not create Lua interpreter.";
    863     return false;
    864   }
    865 
    866   return interpreter->RunIntentGenerator(it->second, remote_actions);
    867 }
    868 
    869 bool IntentGenerator::GenerateIntents(
    870     const jstring device_locales, const ActionSuggestion& action,
    871     const Conversation& conversation, const jobject context,
    872     const reflection::Schema* annotations_entity_data_schema,
    873     const reflection::Schema* actions_entity_data_schema,
    874     std::vector<RemoteActionTemplate>* remote_actions) const {
    875   if (options_ == nullptr) {
    876     return false;
    877   }
    878 
    879   // Retrieve generator for specified action.
    880   auto it = generators_.find(action.type);
    881   if (it == generators_.end()) {
    882     return true;
    883   }
    884 
    885   std::unique_ptr<ActionsJniLuaEnvironment> interpreter(
    886       new ActionsJniLuaEnvironment(
    887           resources_, jni_cache_.get(), context,
    888           ParseDeviceLocales(device_locales), conversation, action,
    889           actions_entity_data_schema, annotations_entity_data_schema));
    890 
    891   if (!interpreter->Initialize()) {
    892     TC3_LOG(ERROR) << "Could not create Lua interpreter.";
    893     return false;
    894   }
    895 
    896   return interpreter->RunIntentGenerator(it->second, remote_actions);
    897 }
    898 
    899 }  // namespace libtextclassifier3
    900