1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "actions/feature-processor.h" 18 19 #include "actions/actions_model_generated.h" 20 #include "annotator/model-executor.h" 21 #include "utils/tensor-view.h" 22 23 #include "gmock/gmock.h" 24 #include "gtest/gtest.h" 25 26 namespace libtextclassifier3 { 27 namespace { 28 29 using testing::FloatEq; 30 31 // EmbeddingExecutor that always returns features based on 32 // the id of the sparse features. 33 class FakeEmbeddingExecutor : public EmbeddingExecutor { 34 public: 35 bool AddEmbedding(const TensorView<int>& sparse_features, float* dest, 36 const int dest_size) const override { 37 TC3_CHECK_GE(dest_size, 4); 38 EXPECT_EQ(sparse_features.size(), 1); 39 dest[0] = sparse_features.data()[0]; 40 dest[1] = sparse_features.data()[0]; 41 dest[2] = -sparse_features.data()[0]; 42 dest[3] = -sparse_features.data()[0]; 43 return true; 44 } 45 46 private: 47 std::vector<float> storage_; 48 }; 49 50 class FeatureProcessorTest : public ::testing::Test { 51 protected: 52 FeatureProcessorTest() : INIT_UNILIB_FOR_TESTING(unilib_) {} 53 54 flatbuffers::DetachedBuffer PackFeatureProcessorOptions( 55 ActionsTokenFeatureProcessorOptionsT* options) const { 56 flatbuffers::FlatBufferBuilder builder; 57 builder.Finish(CreateActionsTokenFeatureProcessorOptions(builder, options)); 58 return builder.Release(); 59 } 60 61 FakeEmbeddingExecutor embedding_executor_; 62 UniLib unilib_; 63 }; 64 65 TEST_F(FeatureProcessorTest, TokenEmbeddings) { 66 ActionsTokenFeatureProcessorOptionsT options; 67 options.embedding_size = 4; 68 options.tokenizer_options.reset(new ActionsTokenizerOptionsT); 69 70 flatbuffers::DetachedBuffer options_fb = 71 PackFeatureProcessorOptions(&options); 72 ActionsFeatureProcessor feature_processor( 73 flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>( 74 options_fb.data()), 75 &unilib_); 76 77 Token token("aaa", 0, 3); 78 std::vector<float> token_features; 79 EXPECT_TRUE(feature_processor.AppendTokenFeatures(token, &embedding_executor_, 80 &token_features)); 81 EXPECT_EQ(token_features.size(), 4); 82 } 83 84 TEST_F(FeatureProcessorTest, TokenEmbeddingsCaseFeature) { 85 ActionsTokenFeatureProcessorOptionsT options; 86 options.embedding_size = 4; 87 options.extract_case_feature = true; 88 options.tokenizer_options.reset(new ActionsTokenizerOptionsT); 89 90 flatbuffers::DetachedBuffer options_fb = 91 PackFeatureProcessorOptions(&options); 92 ActionsFeatureProcessor feature_processor( 93 flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>( 94 options_fb.data()), 95 &unilib_); 96 97 Token token("Aaa", 0, 3); 98 std::vector<float> token_features; 99 EXPECT_TRUE(feature_processor.AppendTokenFeatures(token, &embedding_executor_, 100 &token_features)); 101 EXPECT_EQ(token_features.size(), 5); 102 EXPECT_THAT(token_features[4], FloatEq(1.0)); 103 } 104 105 TEST_F(FeatureProcessorTest, MultipleTokenEmbeddingsCaseFeature) { 106 ActionsTokenFeatureProcessorOptionsT options; 107 options.embedding_size = 4; 108 options.extract_case_feature = true; 109 options.tokenizer_options.reset(new ActionsTokenizerOptionsT); 110 111 flatbuffers::DetachedBuffer options_fb = 112 PackFeatureProcessorOptions(&options); 113 ActionsFeatureProcessor feature_processor( 114 flatbuffers::GetRoot<ActionsTokenFeatureProcessorOptions>( 115 options_fb.data()), 116 &unilib_); 117 118 const std::vector<Token> tokens = {Token("Aaa", 0, 3), Token("bbb", 4, 7), 119 Token("Cccc", 8, 12)}; 120 std::vector<float> token_features; 121 EXPECT_TRUE(feature_processor.AppendTokenFeatures( 122 tokens, &embedding_executor_, &token_features)); 123 EXPECT_EQ(token_features.size(), 15); 124 EXPECT_THAT(token_features[4], FloatEq(1.0)); 125 EXPECT_THAT(token_features[9], FloatEq(-1.0)); 126 EXPECT_THAT(token_features[14], FloatEq(1.0)); 127 } 128 129 } // namespace 130 } // namespace libtextclassifier3 131