Home | History | Annotate | Download | only in actions
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
#include "actions/lua-ranker.h"

#include <memory>
#include <string>
#include <vector>

#include "actions/types.h"
#include "utils/flatbuffers.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
     25 
     26 namespace libtextclassifier3 {
     27 namespace {
     28 
     29 MATCHER_P2(IsAction, type, response_text, "") {
     30   return testing::Value(arg.type, type) &&
     31          testing::Value(arg.response_text, response_text);
     32 }
     33 
     34 MATCHER_P(IsActionType, type, "") { return testing::Value(arg.type, type); }
     35 
     36 std::string TestEntitySchema() {
     37   // Create fake entity data schema meta data.
     38   // Cannot use object oriented API here as that is not available for the
     39   // reflection schema.
     40   flatbuffers::FlatBufferBuilder schema_builder;
     41   std::vector<flatbuffers::Offset<reflection::Field>> fields = {
     42       reflection::CreateField(
     43           schema_builder,
     44           /*name=*/schema_builder.CreateString("test"),
     45           /*type=*/
     46           reflection::CreateType(schema_builder,
     47                                  /*base_type=*/reflection::String),
     48           /*id=*/0,
     49           /*offset=*/4)};
     50   std::vector<flatbuffers::Offset<reflection::Enum>> enums;
     51   std::vector<flatbuffers::Offset<reflection::Object>> objects = {
     52       reflection::CreateObject(
     53           schema_builder,
     54           /*name=*/schema_builder.CreateString("EntityData"),
     55           /*fields=*/
     56           schema_builder.CreateVectorOfSortedTables(&fields))};
     57   schema_builder.Finish(reflection::CreateSchema(
     58       schema_builder, schema_builder.CreateVectorOfSortedTables(&objects),
     59       schema_builder.CreateVectorOfSortedTables(&enums),
     60       /*(unused) file_ident=*/0,
     61       /*(unused) file_ext=*/0,
     62       /*root_table*/ objects[0]));
     63   return std::string(
     64       reinterpret_cast<const char*>(schema_builder.GetBufferPointer()),
     65       schema_builder.GetSize());
     66 }
     67 
     68 TEST(LuaRankingTest, PassThrough) {
     69   const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
     70   ActionsSuggestionsResponse response;
     71   response.actions = {
     72       {/*response_text=*/"hello there", /*type=*/"text_reply",
     73        /*score=*/1.0},
     74       {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
     75       {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
     76   const std::string test_snippet = R"(
     77     local result = {}
     78     for i=1,#actions do
     79       table.insert(result, i)
     80     end
     81     return result
     82   )";
     83 
     84   EXPECT_TRUE(ActionsSuggestionsLuaRanker::Create(
     85                   conversation, test_snippet, /*entity_data_schema=*/nullptr,
     86                   /*annotations_entity_data_schema=*/nullptr, &response)
     87                   ->RankActions());
     88   EXPECT_THAT(response.actions,
     89               testing::ElementsAreArray({IsActionType("text_reply"),
     90                                          IsActionType("share_location"),
     91                                          IsActionType("add_to_collection")}));
     92 }
     93 
     94 TEST(LuaRankingTest, Filtering) {
     95   const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
     96   ActionsSuggestionsResponse response;
     97   response.actions = {
     98       {/*response_text=*/"hello there", /*type=*/"text_reply",
     99        /*score=*/1.0},
    100       {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
    101       {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
    102   const std::string test_snippet = R"(
    103     return {}
    104   )";
    105 
    106   EXPECT_TRUE(ActionsSuggestionsLuaRanker::Create(
    107                   conversation, test_snippet, /*entity_data_schema=*/nullptr,
    108                   /*annotations_entity_data_schema=*/nullptr, &response)
    109                   ->RankActions());
    110   EXPECT_THAT(response.actions, testing::IsEmpty());
    111 }
    112 
    113 TEST(LuaRankingTest, Duplication) {
    114   const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
    115   ActionsSuggestionsResponse response;
    116   response.actions = {
    117       {/*response_text=*/"hello there", /*type=*/"text_reply",
    118        /*score=*/1.0},
    119       {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
    120       {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
    121   const std::string test_snippet = R"(
    122     local result = {}
    123     for i=1,#actions do
    124       table.insert(result, 1)
    125     end
    126     return result
    127   )";
    128 
    129   EXPECT_TRUE(ActionsSuggestionsLuaRanker::Create(
    130                   conversation, test_snippet, /*entity_data_schema=*/nullptr,
    131                   /*annotations_entity_data_schema=*/nullptr, &response)
    132                   ->RankActions());
    133   EXPECT_THAT(response.actions,
    134               testing::ElementsAreArray({IsActionType("text_reply"),
    135                                          IsActionType("text_reply"),
    136                                          IsActionType("text_reply")}));
    137 }
    138 
    139 TEST(LuaRankingTest, SortByScore) {
    140   const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
    141   ActionsSuggestionsResponse response;
    142   response.actions = {
    143       {/*response_text=*/"hello there", /*type=*/"text_reply",
    144        /*score=*/1.0},
    145       {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
    146       {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
    147   const std::string test_snippet = R"(
    148     function testScoreSorter(a, b)
    149       return actions[a].score < actions[b].score
    150     end
    151     local result = {}
    152     for i=1,#actions do
    153       result[i] = i
    154     end
    155     table.sort(result, testScoreSorter)
    156     return result
    157   )";
    158 
    159   EXPECT_TRUE(ActionsSuggestionsLuaRanker::Create(
    160                   conversation, test_snippet, /*entity_data_schema=*/nullptr,
    161                   /*annotations_entity_data_schema=*/nullptr, &response)
    162                   ->RankActions());
    163   EXPECT_THAT(response.actions,
    164               testing::ElementsAreArray({IsActionType("add_to_collection"),
    165                                          IsActionType("share_location"),
    166                                          IsActionType("text_reply")}));
    167 }
    168 
    169 TEST(LuaRankingTest, SuppressType) {
    170   const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
    171   ActionsSuggestionsResponse response;
    172   response.actions = {
    173       {/*response_text=*/"hello there", /*type=*/"text_reply",
    174        /*score=*/1.0},
    175       {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
    176       {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
    177   const std::string test_snippet = R"(
    178     local result = {}
    179     for id, action in pairs(actions) do
    180       if action.type ~= "text_reply" then
    181         table.insert(result, id)
    182       end
    183     end
    184     return result
    185   )";
    186 
    187   EXPECT_TRUE(ActionsSuggestionsLuaRanker::Create(
    188                   conversation, test_snippet, /*entity_data_schema=*/nullptr,
    189                   /*annotations_entity_data_schema=*/nullptr, &response)
    190                   ->RankActions());
    191   EXPECT_THAT(response.actions,
    192               testing::ElementsAreArray({IsActionType("share_location"),
    193                                          IsActionType("add_to_collection")}));
    194 }
    195 
    196 TEST(LuaRankingTest, HandlesConversation) {
    197   const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
    198   ActionsSuggestionsResponse response;
    199   response.actions = {
    200       {/*response_text=*/"hello there", /*type=*/"text_reply",
    201        /*score=*/1.0},
    202       {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
    203       {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
    204   const std::string test_snippet = R"(
    205     local result = {}
    206     if messages[1].text ~= "hello hello" then
    207       return result
    208     end
    209     for id, action in pairs(actions) do
    210       if action.type ~= "text_reply" then
    211         table.insert(result, id)
    212       end
    213     end
    214     return result
    215   )";
    216 
    217   EXPECT_TRUE(ActionsSuggestionsLuaRanker::Create(
    218                   conversation, test_snippet, /*entity_data_schema=*/nullptr,
    219                   /*annotations_entity_data_schema=*/nullptr, &response)
    220                   ->RankActions());
    221   EXPECT_THAT(response.actions,
    222               testing::ElementsAreArray({IsActionType("share_location"),
    223                                          IsActionType("add_to_collection")}));
    224 }
    225 
    226 TEST(LuaRankingTest, HandlesEntityData) {
    227   std::string serialized_schema = TestEntitySchema();
    228   const reflection::Schema* entity_data_schema =
    229       flatbuffers::GetRoot<reflection::Schema>(serialized_schema.data());
    230 
    231   // Create test entity data.
    232   ReflectiveFlatbufferBuilder builder(entity_data_schema);
    233   std::unique_ptr<ReflectiveFlatbuffer> buffer = builder.NewRoot();
    234   buffer->Set("test", "value_a");
    235   const std::string serialized_entity_data_a = buffer->Serialize();
    236   buffer->Set("test", "value_b");
    237   const std::string serialized_entity_data_b = buffer->Serialize();
    238 
    239   const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
    240   ActionsSuggestionsResponse response;
    241   response.actions = {
    242       {/*response_text=*/"", /*type=*/"test",
    243        /*score=*/1.0, /*priority_score=*/1.0, /*annotations=*/{},
    244        /*serialized_entity_data=*/serialized_entity_data_a},
    245       {/*response_text=*/"", /*type=*/"test",
    246        /*score=*/1.0, /*priority_score=*/1.0, /*annotations=*/{},
    247        /*serialized_entity_data=*/serialized_entity_data_b},
    248       {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
    249       {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
    250   const std::string test_snippet = R"(
    251     local result = {}
    252     for id, action in pairs(actions) do
    253       if action.type == "test" and action.test == "value_a" then
    254         table.insert(result, id)
    255       end
    256     end
    257     return result
    258   )";
    259 
    260   EXPECT_TRUE(ActionsSuggestionsLuaRanker::Create(
    261                   conversation, test_snippet, entity_data_schema,
    262                   /*annotations_entity_data_schema=*/nullptr, &response)
    263                   ->RankActions());
    264   EXPECT_THAT(response.actions,
    265               testing::ElementsAreArray({IsActionType("test")}));
    266 }
    267 
    268 }  // namespace
    269 }  // namespace libtextclassifier3
    270