1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "actions/lua-actions.h" 18 19 #include <map> 20 #include <string> 21 22 #include "actions/test_utils.h" 23 #include "actions/types.h" 24 #include "utils/tflite-model-executor.h" 25 #include "gmock/gmock.h" 26 #include "gtest/gtest.h" 27 28 namespace libtextclassifier3 { 29 namespace { 30 31 MATCHER_P2(IsAction, type, response_text, "") { 32 return testing::Value(arg.type, type) && 33 testing::Value(arg.response_text, response_text); 34 } 35 36 MATCHER_P(IsActionType, type, "") { return testing::Value(arg.type, type); } 37 38 TEST(LuaActions, SimpleAction) { 39 Conversation conversation; 40 const std::string test_snippet = R"( 41 return {{ type = "test_action" }} 42 )"; 43 std::vector<ActionSuggestion> actions; 44 EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions( 45 test_snippet, conversation, 46 /*model_executor=*/nullptr, 47 /*model_spec=*/nullptr, 48 /*interpreter=*/nullptr, 49 /*actions_entity_data_schema=*/nullptr, 50 /*annotations_entity_data_schema=*/nullptr) 51 ->SuggestActions(&actions)); 52 EXPECT_THAT(actions, 53 testing::ElementsAreArray({IsActionType("test_action")})); 54 } 55 56 TEST(LuaActions, ConversationActions) { 57 Conversation conversation; 58 conversation.messages.push_back({/*user_id=*/0, "hello there!"}); 59 conversation.messages.push_back({/*user_id=*/1, "general kenobi!"}); 60 const std::string test_snippet = R"( 61 local actions = {} 62 for i, message in pairs(messages) do 63 if i < #messages then 64 if message.text == "hello there!" and 65 messages[i+1].text == "general kenobi!" then 66 table.insert(actions, { 67 type = "text_reply", 68 response_text = "you are a bold one!" 69 }) 70 end 71 if message.text == "i am the senate!" and 72 messages[i+1].text == "not yet!" then 73 table.insert(actions, { 74 type = "text_reply", 75 response_text = "it's treason then" 76 }) 77 end 78 end 79 end 80 return actions; 81 )"; 82 std::vector<ActionSuggestion> actions; 83 EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions( 84 test_snippet, conversation, 85 /*model_executor=*/nullptr, 86 /*model_spec=*/nullptr, 87 /*interpreter=*/nullptr, 88 /*actions_entity_data_schema=*/nullptr, 89 /*annotations_entity_data_schema=*/nullptr) 90 ->SuggestActions(&actions)); 91 EXPECT_THAT(actions, testing::ElementsAreArray( 92 {IsAction("text_reply", "you are a bold one!")})); 93 } 94 95 TEST(LuaActions, SimpleModelAction) { 96 Conversation conversation; 97 const std::string test_snippet = R"( 98 if #model.actions_scores == 0 then 99 return {{ type = "test_action" }} 100 end 101 return {} 102 )"; 103 std::vector<ActionSuggestion> actions; 104 EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions( 105 test_snippet, conversation, 106 /*model_executor=*/nullptr, 107 /*model_spec=*/nullptr, 108 /*interpreter=*/nullptr, 109 /*actions_entity_data_schema=*/nullptr, 110 /*annotations_entity_data_schema=*/nullptr) 111 ->SuggestActions(&actions)); 112 EXPECT_THAT(actions, 113 testing::ElementsAreArray({IsActionType("test_action")})); 114 } 115 116 TEST(LuaActions, AnnotationActions) { 117 AnnotatedSpan annotation; 118 annotation.span = {11, 15}; 119 annotation.classification = {ClassificationResult("address", 1.0)}; 120 Conversation conversation = {{{/*user_id=*/1, "are you at home?", 121 /*reference_time_ms_utc=*/0, 122 /*reference_timezone=*/"Europe/Zurich", 123 /*annotations=*/{annotation}, 124 /*locales=*/"en"}}}; 125 const std::string test_snippet = R"( 126 local actions = {} 127 local last_message = messages[#messages] 128 for i, annotation in pairs(last_message.annotation) do 129 if #annotation.classification > 0 then 130 if annotation.classification[1].collection == "address" then 131 local text = string.sub(last_message.text, 132 annotation.span["begin"] + 1, 133 annotation.span["end"]) 134 table.insert(actions, { 135 type = "text_reply", 136 response_text = "i am at " .. text, 137 annotation = {{ 138 name = "location", 139 span = { 140 text = text 141 }, 142 entity = annotation.classification[1] 143 }}, 144 }) 145 end 146 end 147 end 148 return actions; 149 )"; 150 std::vector<ActionSuggestion> actions; 151 EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions( 152 test_snippet, conversation, 153 /*model_executor=*/nullptr, 154 /*model_spec=*/nullptr, 155 /*interpreter=*/nullptr, 156 /*actions_entity_data_schema=*/nullptr, 157 /*annotations_entity_data_schema=*/nullptr) 158 ->SuggestActions(&actions)); 159 EXPECT_THAT(actions, testing::ElementsAreArray( 160 {IsAction("text_reply", "i am at home")})); 161 EXPECT_EQ("address", actions[0].annotations[0].entity.collection); 162 } 163 164 TEST(LuaActions, EntityData) { 165 std::string test_schema = TestEntityDataSchema(); 166 Conversation conversation = {{{/*user_id=*/1, "hello there"}}}; 167 const std::string test_snippet = R"( 168 return {{ 169 type = "test", 170 entity = { 171 greeting = "hello", 172 location = "there", 173 person = "Kenobi", 174 }, 175 }}; 176 )"; 177 std::vector<ActionSuggestion> actions; 178 EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions( 179 test_snippet, conversation, 180 /*model_executor=*/nullptr, 181 /*model_spec=*/nullptr, 182 /*interpreter=*/nullptr, 183 /*actions_entity_data_schema=*/ 184 flatbuffers::GetRoot<reflection::Schema>(test_schema.data()), 185 /*annotations_entity_data_schema=*/nullptr) 186 ->SuggestActions(&actions)); 187 EXPECT_THAT(actions, testing::SizeIs(1)); 188 EXPECT_EQ("test", actions.front().type); 189 const flatbuffers::Table* entity = 190 flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>( 191 actions.front().serialized_entity_data.data())); 192 EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(), 193 "hello"); 194 EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(), 195 "there"); 196 EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(), 197 "Kenobi"); 198 } 199 200 } // namespace 201 } // namespace libtextclassifier3 202