Home | History | Annotate | Download | only in actions
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "actions/lua-actions.h"
     18 
     19 #include <map>
     20 #include <string>
     21 
     22 #include "actions/test_utils.h"
     23 #include "actions/types.h"
     24 #include "utils/tflite-model-executor.h"
     25 #include "gmock/gmock.h"
     26 #include "gtest/gtest.h"
     27 
     28 namespace libtextclassifier3 {
     29 namespace {
     30 
     31 MATCHER_P2(IsAction, type, response_text, "") {
     32   return testing::Value(arg.type, type) &&
     33          testing::Value(arg.response_text, response_text);
     34 }
     35 
     36 MATCHER_P(IsActionType, type, "") { return testing::Value(arg.type, type); }
     37 
     38 TEST(LuaActions, SimpleAction) {
     39   Conversation conversation;
     40   const std::string test_snippet = R"(
     41     return {{ type = "test_action" }}
     42   )";
     43   std::vector<ActionSuggestion> actions;
     44   EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
     45                   test_snippet, conversation,
     46                   /*model_executor=*/nullptr,
     47                   /*model_spec=*/nullptr,
     48                   /*interpreter=*/nullptr,
     49                   /*actions_entity_data_schema=*/nullptr,
     50                   /*annotations_entity_data_schema=*/nullptr)
     51                   ->SuggestActions(&actions));
     52   EXPECT_THAT(actions,
     53               testing::ElementsAreArray({IsActionType("test_action")}));
     54 }
     55 
     56 TEST(LuaActions, ConversationActions) {
     57   Conversation conversation;
     58   conversation.messages.push_back({/*user_id=*/0, "hello there!"});
     59   conversation.messages.push_back({/*user_id=*/1, "general kenobi!"});
     60   const std::string test_snippet = R"(
     61     local actions = {}
     62     for i, message in pairs(messages) do
     63       if i < #messages then
     64         if message.text == "hello there!" and
     65            messages[i+1].text == "general kenobi!" then
     66            table.insert(actions, {
     67              type = "text_reply",
     68              response_text = "you are a bold one!"
     69            })
     70         end
     71         if message.text == "i am the senate!" and
     72            messages[i+1].text == "not yet!" then
     73            table.insert(actions, {
     74              type = "text_reply",
     75              response_text = "it's treason then"
     76            })
     77         end
     78       end
     79     end
     80     return actions;
     81   )";
     82   std::vector<ActionSuggestion> actions;
     83   EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
     84                   test_snippet, conversation,
     85                   /*model_executor=*/nullptr,
     86                   /*model_spec=*/nullptr,
     87                   /*interpreter=*/nullptr,
     88                   /*actions_entity_data_schema=*/nullptr,
     89                   /*annotations_entity_data_schema=*/nullptr)
     90                   ->SuggestActions(&actions));
     91   EXPECT_THAT(actions, testing::ElementsAreArray(
     92                            {IsAction("text_reply", "you are a bold one!")}));
     93 }
     94 
     95 TEST(LuaActions, SimpleModelAction) {
     96   Conversation conversation;
     97   const std::string test_snippet = R"(
     98     if #model.actions_scores == 0 then
     99       return {{ type = "test_action" }}
    100     end
    101     return {}
    102   )";
    103   std::vector<ActionSuggestion> actions;
    104   EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
    105                   test_snippet, conversation,
    106                   /*model_executor=*/nullptr,
    107                   /*model_spec=*/nullptr,
    108                   /*interpreter=*/nullptr,
    109                   /*actions_entity_data_schema=*/nullptr,
    110                   /*annotations_entity_data_schema=*/nullptr)
    111                   ->SuggestActions(&actions));
    112   EXPECT_THAT(actions,
    113               testing::ElementsAreArray({IsActionType("test_action")}));
    114 }
    115 
    116 TEST(LuaActions, AnnotationActions) {
    117   AnnotatedSpan annotation;
    118   annotation.span = {11, 15};
    119   annotation.classification = {ClassificationResult("address", 1.0)};
    120   Conversation conversation = {{{/*user_id=*/1, "are you at home?",
    121                                  /*reference_time_ms_utc=*/0,
    122                                  /*reference_timezone=*/"Europe/Zurich",
    123                                  /*annotations=*/{annotation},
    124                                  /*locales=*/"en"}}};
    125   const std::string test_snippet = R"(
    126     local actions = {}
    127     local last_message = messages[#messages]
    128     for i, annotation in pairs(last_message.annotation) do
    129       if #annotation.classification > 0 then
    130         if annotation.classification[1].collection == "address" then
    131            local text = string.sub(last_message.text,
    132                             annotation.span["begin"] + 1,
    133                             annotation.span["end"])
    134            table.insert(actions, {
    135              type = "text_reply",
    136              response_text = "i am at " .. text,
    137              annotation = {{
    138                name = "location",
    139                span = {
    140                  text = text
    141                },
    142                entity = annotation.classification[1]
    143              }},
    144            })
    145         end
    146       end
    147     end
    148     return actions;
    149   )";
    150   std::vector<ActionSuggestion> actions;
    151   EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
    152                   test_snippet, conversation,
    153                   /*model_executor=*/nullptr,
    154                   /*model_spec=*/nullptr,
    155                   /*interpreter=*/nullptr,
    156                   /*actions_entity_data_schema=*/nullptr,
    157                   /*annotations_entity_data_schema=*/nullptr)
    158                   ->SuggestActions(&actions));
    159   EXPECT_THAT(actions, testing::ElementsAreArray(
    160                            {IsAction("text_reply", "i am at home")}));
    161   EXPECT_EQ("address", actions[0].annotations[0].entity.collection);
    162 }
    163 
    164 TEST(LuaActions, EntityData) {
    165   std::string test_schema = TestEntityDataSchema();
    166   Conversation conversation = {{{/*user_id=*/1, "hello there"}}};
    167   const std::string test_snippet = R"(
    168     return {{
    169       type = "test",
    170       entity = {
    171         greeting = "hello",
    172         location = "there",
    173         person = "Kenobi",
    174       },
    175     }};
    176   )";
    177   std::vector<ActionSuggestion> actions;
    178   EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
    179                   test_snippet, conversation,
    180                   /*model_executor=*/nullptr,
    181                   /*model_spec=*/nullptr,
    182                   /*interpreter=*/nullptr,
    183                   /*actions_entity_data_schema=*/
    184                   flatbuffers::GetRoot<reflection::Schema>(test_schema.data()),
    185                   /*annotations_entity_data_schema=*/nullptr)
    186                   ->SuggestActions(&actions));
    187   EXPECT_THAT(actions, testing::SizeIs(1));
    188   EXPECT_EQ("test", actions.front().type);
    189   const flatbuffers::Table* entity =
    190       flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
    191           actions.front().serialized_entity_data.data()));
    192   EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
    193             "hello");
    194   EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
    195             "there");
    196   EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
    197             "Kenobi");
    198 }
    199 
    200 }  // namespace
    201 }  // namespace libtextclassifier3
    202