1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "actions/lua-ranker.h" 18 #include "utils/base/logging.h" 19 #include "utils/lua-utils.h" 20 21 #ifdef __cplusplus 22 extern "C" { 23 #endif 24 #include "lauxlib.h" 25 #include "lualib.h" 26 #ifdef __cplusplus 27 } 28 #endif 29 30 namespace libtextclassifier3 { 31 32 std::unique_ptr<ActionsSuggestionsLuaRanker> 33 ActionsSuggestionsLuaRanker::Create( 34 const Conversation& conversation, const std::string& ranker_code, 35 const reflection::Schema* entity_data_schema, 36 const reflection::Schema* annotations_entity_data_schema, 37 ActionsSuggestionsResponse* response) { 38 auto ranker = std::unique_ptr<ActionsSuggestionsLuaRanker>( 39 new ActionsSuggestionsLuaRanker( 40 conversation, ranker_code, entity_data_schema, 41 annotations_entity_data_schema, response)); 42 if (!ranker->Initialize()) { 43 TC3_LOG(ERROR) << "Could not initialize lua environment for ranker."; 44 return nullptr; 45 } 46 return ranker; 47 } 48 49 bool ActionsSuggestionsLuaRanker::Initialize() { 50 return RunProtected([this] { 51 LoadDefaultLibraries(); 52 53 // Expose generated actions. 54 actions_iterator_.NewIterator("actions", &response_->actions, 55 state_); 56 lua_setglobal(state_, "actions"); 57 58 // Expose conversation message stream. 59 conversation_iterator_.NewIterator("messages", 60 &conversation_.messages, state_); 61 lua_setglobal(state_, "messages"); 62 return LUA_OK; 63 }) == LUA_OK; 64 } 65 66 int ActionsSuggestionsLuaRanker::ReadActionsRanking() { 67 if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) { 68 TC3_LOG(ERROR) << "Expected actions table, got: " 69 << lua_type(state_, /*idx=*/-1); 70 lua_pop(state_, 1); 71 lua_error(state_); 72 return LUA_ERRRUN; 73 } 74 std::vector<ActionSuggestion> ranked_actions; 75 lua_pushnil(state_); 76 while (lua_next(state_, /*idx=*/-2)) { 77 const int action_id = 78 static_cast<int>(lua_tointeger(state_, /*idx=*/-1)) - 1; 79 lua_pop(state_, 1); 80 if (action_id < 0 || action_id >= response_->actions.size()) { 81 TC3_LOG(ERROR) << "Invalid action index: " << action_id; 82 lua_error(state_); 83 return LUA_ERRRUN; 84 } 85 ranked_actions.push_back(response_->actions[action_id]); 86 } 87 lua_pop(state_, 1); 88 response_->actions = ranked_actions; 89 return LUA_OK; 90 } 91 92 bool ActionsSuggestionsLuaRanker::RankActions() { 93 if (response_->actions.empty()) { 94 // Nothing to do. 95 return true; 96 } 97 98 if (luaL_loadbuffer(state_, ranker_code_.data(), ranker_code_.size(), 99 /*name=*/nullptr) != LUA_OK) { 100 TC3_LOG(ERROR) << "Could not load compiled ranking snippet."; 101 return false; 102 } 103 104 if (lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0) != LUA_OK) { 105 TC3_LOG(ERROR) << "Could not run ranking snippet."; 106 return false; 107 } 108 109 if (RunProtected([this] { return ReadActionsRanking(); }, /*num_args=*/1) != 110 LUA_OK) { 111 TC3_LOG(ERROR) << "Could not read lua result."; 112 return false; 113 } 114 return true; 115 } 116 117 } // namespace libtextclassifier3 118