Home | History | Annotate | Download | only in actions
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "actions/feature-processor.h"
     18 
     19 #include "actions/actions_model_generated.h"
     20 #include "annotator/model-executor.h"
     21 #include "utils/tensor-view.h"
     22 
     23 #include "gmock/gmock.h"
     24 #include "gtest/gtest.h"
     25 
     26 namespace libtextclassifier3 {
     27 namespace {
     28 
     29 using testing::FloatEq;
     30 
     31 // EmbeddingExecutor that always returns features based on
     32 // the id of the sparse features.
     33 class FakeEmbeddingExecutor : public EmbeddingExecutor {
     34  public:
     35   bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
     36                     const int dest_size) const override {
     37     TC3_CHECK_GE(dest_size, 4);
     38     EXPECT_EQ(sparse_features.size(), 1);
     39     dest[0] = sparse_features.data()[0];
     40     dest[1] = sparse_features.data()[0];
     41     dest[2] = -sparse_features.data()[0];
     42     dest[3] = -sparse_features.data()[0];
     43     return true;
     44   }
     45 
     46  private:
     47   std::vector<float> storage_;
     48 };
     49 
     50 class FeatureProcessorTest : public ::testing::Test {
     51  protected:
     52   FeatureProcessorTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
     53 
     54   flatbuffers::DetachedBuffer PackFeatureProcessorOptions(
     55       ActionsTokenFeatureProcessorOptionsT* options) const {
     56     flatbuffers::FlatBufferBuilder builder;
     57     builder.Finish(CreateActionsTokenFeatureProcessorOptions(builder, options));
     58     return builder.Release();
     59   }
     60 
     61   FakeEmbeddingExecutor embedding_executor_;
     62   UniLib unilib_;
     63 };
     64 
     65 TEST_F(FeatureProcessorTest, TokenEmbeddings) {
     66   ActionsTokenFeatureProcessorOptionsT options;
     67   options.embedding_size = 4;
     68   options.tokenizer_options.reset(new ActionsTokenizerOptionsT);
     69 
     70   flatbuffers::DetachedBuffer options_fb =
     71       PackFeatureProcessorOptions(&options);
     72   ActionsFeatureProcessor feature_processor(
     73       flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>(
     74           options_fb.data()),
     75       &unilib_);
     76 
     77   Token token("aaa", 0, 3);
     78   std::vector<float> token_features;
     79   EXPECT_TRUE(feature_processor.AppendTokenFeatures(token, &embedding_executor_,
     80                                                     &token_features));
     81   EXPECT_EQ(token_features.size(), 4);
     82 }
     83 
     84 TEST_F(FeatureProcessorTest, TokenEmbeddingsCaseFeature) {
     85   ActionsTokenFeatureProcessorOptionsT options;
     86   options.embedding_size = 4;
     87   options.extract_case_feature = true;
     88   options.tokenizer_options.reset(new ActionsTokenizerOptionsT);
     89 
     90   flatbuffers::DetachedBuffer options_fb =
     91       PackFeatureProcessorOptions(&options);
     92   ActionsFeatureProcessor feature_processor(
     93       flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>(
     94           options_fb.data()),
     95       &unilib_);
     96 
     97   Token token("Aaa", 0, 3);
     98   std::vector<float> token_features;
     99   EXPECT_TRUE(feature_processor.AppendTokenFeatures(token, &embedding_executor_,
    100                                                     &token_features));
    101   EXPECT_EQ(token_features.size(), 5);
    102   EXPECT_THAT(token_features[4], FloatEq(1.0));
    103 }
    104 
    105 TEST_F(FeatureProcessorTest, MultipleTokenEmbeddingsCaseFeature) {
    106   ActionsTokenFeatureProcessorOptionsT options;
    107   options.embedding_size = 4;
    108   options.extract_case_feature = true;
    109   options.tokenizer_options.reset(new ActionsTokenizerOptionsT);
    110 
    111   flatbuffers::DetachedBuffer options_fb =
    112       PackFeatureProcessorOptions(&options);
    113   ActionsFeatureProcessor feature_processor(
    114       flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>(
    115           options_fb.data()),
    116       &unilib_);
    117 
    118   const std::vector<Token> tokens = {Token("Aaa", 0, 3), Token("bbb", 4, 7),
    119                                      Token("Cccc", 8, 12)};
    120   std::vector<float> token_features;
    121   EXPECT_TRUE(feature_processor.AppendTokenFeatures(
    122       tokens, &embedding_executor_, &token_features));
    123   EXPECT_EQ(token_features.size(), 15);
    124   EXPECT_THAT(token_features[4], FloatEq(1.0));
    125   EXPECT_THAT(token_features[9], FloatEq(-1.0));
    126   EXPECT_THAT(token_features[14], FloatEq(1.0));
    127 }
    128 
    129 }  // namespace
    130 }  // namespace libtextclassifier3
    131