1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "actions/lua-ranker.h" 18 19 #include <string> 20 21 #include "actions/types.h" 22 #include "utils/flatbuffers.h" 23 #include "gmock/gmock.h" 24 #include "gtest/gtest.h" 25 26 namespace libtextclassifier3 { 27 namespace { 28 29 MATCHER_P2(IsAction, type, response_text, "") { 30 return testing::Value(arg.type, type) && 31 testing::Value(arg.response_text, response_text); 32 } 33 34 MATCHER_P(IsActionType, type, "") { return testing::Value(arg.type, type); } 35 36 std::string TestEntitySchema() { 37 // Create fake entity data schema meta data. 38 // Cannot use object oriented API here as that is not available for the 39 // reflection schema. 40 flatbuffers::FlatBufferBuilder schema_builder; 41 std::vector<flatbuffers::Offset<reflection::Field>> fields = { 42 reflection::CreateField( 43 schema_builder, 44 /*name=*/schema_builder.CreateString("test"), 45 /*type=*/ 46 reflection::CreateType(schema_builder, 47 /*base_type=*/reflection::String), 48 /*id=*/0, 49 /*offset=*/4)}; 50 std::vector<flatbuffers::Offset<reflection::Enum>> enums; 51 std::vector<flatbuffers::Offset<reflection::Object>> objects = { 52 reflection::CreateObject( 53 schema_builder, 54 /*name=*/schema_builder.CreateString("EntityData"), 55 /*fields=*/ 56 schema_builder.CreateVectorOfSortedTables(&fields))}; 57 schema_builder.Finish(reflection::CreateSchema( 58 schema_builder, schema_builder.CreateVectorOfSortedTables(&objects), 59 schema_builder.CreateVectorOfSortedTables(&enums), 60 /*(unused) file_ident=*/0, 61 /*(unused) file_ext=*/0, 62 /*root_table*/ objects[0])); 63 return std::string( 64 reinterpret_cast<const char*>(schema_builder.GetBufferPointer()), 65 schema_builder.GetSize()); 66 } 67 68 TEST(LuaRankingTest, PassThrough) { 69 const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}}; 70 ActionsSuggestionsResponse response; 71 response.actions = { 72 {/*response_text=*/"hello there", /*type=*/"text_reply", 73 /*score=*/1.0}, 74 {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5}, 75 {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}}; 76 const std::string test_snippet = R"( 77 local result = {} 78 for i=1,#actions do 79 table.insert(result, i) 80 end 81 return result 82 )"; 83 84 EXPECT_TRUE(ActionsSuggestionsLuaRanker::Create( 85 conversation, test_snippet, /*entity_data_schema=*/nullptr, 86 /*annotations_entity_data_schema=*/nullptr, &response) 87 ->RankActions()); 88 EXPECT_THAT(response.actions, 89 testing::ElementsAreArray({IsActionType("text_reply"), 90 IsActionType("share_location"), 91 IsActionType("add_to_collection")})); 92 } 93 94 TEST(LuaRankingTest, Filtering) { 95 const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}}; 96 ActionsSuggestionsResponse response; 97 response.actions = { 98 {/*response_text=*/"hello there", /*type=*/"text_reply", 99 /*score=*/1.0}, 100 {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5}, 101 {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}}; 102 const std::string test_snippet = R"( 103 return {} 104 )"; 105 106 EXPECT_TRUE(ActionsSuggestionsLuaRanker::Create( 107 conversation, test_snippet, /*entity_data_schema=*/nullptr, 108 /*annotations_entity_data_schema=*/nullptr, &response) 109 ->RankActions()); 110 EXPECT_THAT(response.actions, testing::IsEmpty()); 111 } 112 113 TEST(LuaRankingTest, Duplication) { 114 const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}}; 115 ActionsSuggestionsResponse response; 116 response.actions = { 117 {/*response_text=*/"hello there", /*type=*/"text_reply", 118 /*score=*/1.0}, 119 {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5}, 120 {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}}; 121 const std::string test_snippet = R"( 122 local result = {} 123 for i=1,#actions do 124 table.insert(result, 1) 125 end 126 return result 127 )"; 128 129 EXPECT_TRUE(ActionsSuggestionsLuaRanker::Create( 130 conversation, test_snippet, /*entity_data_schema=*/nullptr, 131 /*annotations_entity_data_schema=*/nullptr, &response) 132 ->RankActions()); 133 EXPECT_THAT(response.actions, 134 testing::ElementsAreArray({IsActionType("text_reply"), 135 IsActionType("text_reply"), 136 IsActionType("text_reply")})); 137 } 138 139 TEST(LuaRankingTest, SortByScore) { 140 const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}}; 141 ActionsSuggestionsResponse response; 142 response.actions = { 143 {/*response_text=*/"hello there", /*type=*/"text_reply", 144 /*score=*/1.0}, 145 {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5}, 146 {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}}; 147 const std::string test_snippet = R"( 148 function testScoreSorter(a, b) 149 return actions[a].score < actions[b].score 150 end 151 local result = {} 152 for i=1,#actions do 153 result[i] = i 154 end 155 table.sort(result, testScoreSorter) 156 return result 157 )"; 158 159 EXPECT_TRUE(ActionsSuggestionsLuaRanker::Create( 160 conversation, test_snippet, /*entity_data_schema=*/nullptr, 161 /*annotations_entity_data_schema=*/nullptr, &response) 162 ->RankActions()); 163 EXPECT_THAT(response.actions, 164 testing::ElementsAreArray({IsActionType("add_to_collection"), 165 IsActionType("share_location"), 166 IsActionType("text_reply")})); 167 } 168 169 TEST(LuaRankingTest, SuppressType) { 170 const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}}; 171 ActionsSuggestionsResponse response; 172 response.actions = { 173 {/*response_text=*/"hello there", /*type=*/"text_reply", 174 /*score=*/1.0}, 175 {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5}, 176 {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}}; 177 const std::string test_snippet = R"( 178 local result = {} 179 for id, action in pairs(actions) do 180 if action.type ~= "text_reply" then 181 table.insert(result, id) 182 end 183 end 184 return result 185 )"; 186 187 EXPECT_TRUE(ActionsSuggestionsLuaRanker::Create( 188 conversation, test_snippet, /*entity_data_schema=*/nullptr, 189 /*annotations_entity_data_schema=*/nullptr, &response) 190 ->RankActions()); 191 EXPECT_THAT(response.actions, 192 testing::ElementsAreArray({IsActionType("share_location"), 193 IsActionType("add_to_collection")})); 194 } 195 196 TEST(LuaRankingTest, HandlesConversation) { 197 const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}}; 198 ActionsSuggestionsResponse response; 199 response.actions = { 200 {/*response_text=*/"hello there", /*type=*/"text_reply", 201 /*score=*/1.0}, 202 {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5}, 203 {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}}; 204 const std::string test_snippet = R"( 205 local result = {} 206 if messages[1].text ~= "hello hello" then 207 return result 208 end 209 for id, action in pairs(actions) do 210 if action.type ~= "text_reply" then 211 table.insert(result, id) 212 end 213 end 214 return result 215 )"; 216 217 EXPECT_TRUE(ActionsSuggestionsLuaRanker::Create( 218 conversation, test_snippet, /*entity_data_schema=*/nullptr, 219 /*annotations_entity_data_schema=*/nullptr, &response) 220 ->RankActions()); 221 EXPECT_THAT(response.actions, 222 testing::ElementsAreArray({IsActionType("share_location"), 223 IsActionType("add_to_collection")})); 224 } 225 226 TEST(LuaRankingTest, HandlesEntityData) { 227 std::string serialized_schema = TestEntitySchema(); 228 const reflection::Schema* entity_data_schema = 229 flatbuffers::GetRoot<reflection::Schema>(serialized_schema.data()); 230 231 // Create test entity data. 232 ReflectiveFlatbufferBuilder builder(entity_data_schema); 233 std::unique_ptr<ReflectiveFlatbuffer> buffer = builder.NewRoot(); 234 buffer->Set("test", "value_a"); 235 const std::string serialized_entity_data_a = buffer->Serialize(); 236 buffer->Set("test", "value_b"); 237 const std::string serialized_entity_data_b = buffer->Serialize(); 238 239 const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}}; 240 ActionsSuggestionsResponse response; 241 response.actions = { 242 {/*response_text=*/"", /*type=*/"test", 243 /*score=*/1.0, /*priority_score=*/1.0, /*annotations=*/{}, 244 /*serialized_entity_data=*/serialized_entity_data_a}, 245 {/*response_text=*/"", /*type=*/"test", 246 /*score=*/1.0, /*priority_score=*/1.0, /*annotations=*/{}, 247 /*serialized_entity_data=*/serialized_entity_data_b}, 248 {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5}, 249 {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}}; 250 const std::string test_snippet = R"( 251 local result = {} 252 for id, action in pairs(actions) do 253 if action.type == "test" and action.test == "value_a" then 254 table.insert(result, id) 255 end 256 end 257 return result 258 )"; 259 260 EXPECT_TRUE(ActionsSuggestionsLuaRanker::Create( 261 conversation, test_snippet, entity_data_schema, 262 /*annotations_entity_data_schema=*/nullptr, &response) 263 ->RankActions()); 264 EXPECT_THAT(response.actions, 265 testing::ElementsAreArray({IsActionType("test")})); 266 } 267 268 } // namespace 269 } // namespace libtextclassifier3 270