Home | History | Annotate | Download | only in actions
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
#include "actions/lua-ranker.h"

#include <utility>
#include <vector>

#include "utils/base/logging.h"
#include "utils/lua-utils.h"

#ifdef __cplusplus
extern "C" {
#endif
#include "lauxlib.h"
#include "lualib.h"
#ifdef __cplusplus
}
#endif
     29 
     30 namespace libtextclassifier3 {
     31 
     32 std::unique_ptr<ActionsSuggestionsLuaRanker>
     33 ActionsSuggestionsLuaRanker::Create(
     34     const Conversation& conversation, const std::string& ranker_code,
     35     const reflection::Schema* entity_data_schema,
     36     const reflection::Schema* annotations_entity_data_schema,
     37     ActionsSuggestionsResponse* response) {
     38   auto ranker = std::unique_ptr<ActionsSuggestionsLuaRanker>(
     39       new ActionsSuggestionsLuaRanker(
     40           conversation, ranker_code, entity_data_schema,
     41           annotations_entity_data_schema, response));
     42   if (!ranker->Initialize()) {
     43     TC3_LOG(ERROR) << "Could not initialize lua environment for ranker.";
     44     return nullptr;
     45   }
     46   return ranker;
     47 }
     48 
     49 bool ActionsSuggestionsLuaRanker::Initialize() {
     50   return RunProtected([this] {
     51            LoadDefaultLibraries();
     52 
     53            // Expose generated actions.
     54            actions_iterator_.NewIterator("actions", &response_->actions,
     55                                          state_);
     56            lua_setglobal(state_, "actions");
     57 
     58            // Expose conversation message stream.
     59            conversation_iterator_.NewIterator("messages",
     60                                               &conversation_.messages, state_);
     61            lua_setglobal(state_, "messages");
     62            return LUA_OK;
     63          }) == LUA_OK;
     64 }
     65 
     66 int ActionsSuggestionsLuaRanker::ReadActionsRanking() {
     67   if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
     68     TC3_LOG(ERROR) << "Expected actions table, got: "
     69                    << lua_type(state_, /*idx=*/-1);
     70     lua_pop(state_, 1);
     71     lua_error(state_);
     72     return LUA_ERRRUN;
     73   }
     74   std::vector<ActionSuggestion> ranked_actions;
     75   lua_pushnil(state_);
     76   while (lua_next(state_, /*idx=*/-2)) {
     77     const int action_id =
     78         static_cast<int>(lua_tointeger(state_, /*idx=*/-1)) - 1;
     79     lua_pop(state_, 1);
     80     if (action_id < 0 || action_id >= response_->actions.size()) {
     81       TC3_LOG(ERROR) << "Invalid action index: " << action_id;
     82       lua_error(state_);
     83       return LUA_ERRRUN;
     84     }
     85     ranked_actions.push_back(response_->actions[action_id]);
     86   }
     87   lua_pop(state_, 1);
     88   response_->actions = ranked_actions;
     89   return LUA_OK;
     90 }
     91 
     92 bool ActionsSuggestionsLuaRanker::RankActions() {
     93   if (response_->actions.empty()) {
     94     // Nothing to do.
     95     return true;
     96   }
     97 
     98   if (luaL_loadbuffer(state_, ranker_code_.data(), ranker_code_.size(),
     99                       /*name=*/nullptr) != LUA_OK) {
    100     TC3_LOG(ERROR) << "Could not load compiled ranking snippet.";
    101     return false;
    102   }
    103 
    104   if (lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0) != LUA_OK) {
    105     TC3_LOG(ERROR) << "Could not run ranking snippet.";
    106     return false;
    107   }
    108 
    109   if (RunProtected([this] { return ReadActionsRanking(); }, /*num_args=*/1) !=
    110       LUA_OK) {
    111     TC3_LOG(ERROR) << "Could not read lua result.";
    112     return false;
    113   }
    114   return true;
    115 }
    116 
    117 }  // namespace libtextclassifier3
    118