1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "common/embedding-feature-extractor.h" 18 19 #include "lang_id/language-identifier-features.h" 20 #include "lang_id/light-sentence-features.h" 21 #include "lang_id/light-sentence.h" 22 #include "lang_id/relevant-script-feature.h" 23 #include "gtest/gtest.h" 24 25 namespace libtextclassifier { 26 namespace nlp_core { 27 28 class EmbeddingFeatureExtractorTest : public ::testing::Test { 29 public: 30 void SetUp() override { 31 // Make sure all relevant features are registered: 32 lang_id::ContinuousBagOfNgramsFunction::RegisterClass(); 33 lang_id::RelevantScriptFeature::RegisterClass(); 34 } 35 }; 36 37 // Specialization of EmbeddingFeatureExtractor that extracts from LightSentence. 38 class TestEmbeddingFeatureExtractor 39 : public EmbeddingFeatureExtractor<lang_id::LightSentenceExtractor, 40 lang_id::LightSentence> { 41 public: 42 const std::string ArgPrefix() const override { return "test"; } 43 }; 44 45 TEST_F(EmbeddingFeatureExtractorTest, NoEmbeddingSpaces) { 46 TaskContext context; 47 context.SetParameter("test_features", ""); 48 context.SetParameter("test_embedding_names", ""); 49 context.SetParameter("test_embedding_dims", ""); 50 TestEmbeddingFeatureExtractor tefe; 51 ASSERT_TRUE(tefe.Init(&context)); 52 EXPECT_EQ(tefe.NumEmbeddings(), 0); 53 } 54 55 TEST_F(EmbeddingFeatureExtractorTest, GoodSpec) { 56 TaskContext context; 57 const std::string spec = 58 "continuous-bag-of-ngrams(id_dim=5000,size=3);" 59 "continuous-bag-of-ngrams(id_dim=7000,size=4)"; 60 context.SetParameter("test_features", spec); 61 context.SetParameter("test_embedding_names", "trigram;quadgram"); 62 context.SetParameter("test_embedding_dims", "16;24"); 63 TestEmbeddingFeatureExtractor tefe; 64 ASSERT_TRUE(tefe.Init(&context)); 65 EXPECT_EQ(tefe.NumEmbeddings(), 2); 66 EXPECT_EQ(tefe.EmbeddingSize(0), 5000); 67 EXPECT_EQ(tefe.EmbeddingDims(0), 16); 68 EXPECT_EQ(tefe.EmbeddingSize(1), 7000); 69 EXPECT_EQ(tefe.EmbeddingDims(1), 24); 70 } 71 72 TEST_F(EmbeddingFeatureExtractorTest, MissmatchFmlVsNames) { 73 TaskContext context; 74 const std::string spec = 75 "continuous-bag-of-ngrams(id_dim=5000,size=3);" 76 "continuous-bag-of-ngrams(id_dim=7000,size=4)"; 77 context.SetParameter("test_features", spec); 78 context.SetParameter("test_embedding_names", "trigram"); 79 context.SetParameter("test_embedding_dims", "16;16"); 80 TestEmbeddingFeatureExtractor tefe; 81 ASSERT_FALSE(tefe.Init(&context)); 82 } 83 84 TEST_F(EmbeddingFeatureExtractorTest, MissmatchFmlVsDims) { 85 TaskContext context; 86 const std::string spec = 87 "continuous-bag-of-ngrams(id_dim=5000,size=3);" 88 "continuous-bag-of-ngrams(id_dim=7000,size=4)"; 89 context.SetParameter("test_features", spec); 90 context.SetParameter("test_embedding_names", "trigram;quadgram"); 91 context.SetParameter("test_embedding_dims", "16;16;32"); 92 TestEmbeddingFeatureExtractor tefe; 93 ASSERT_FALSE(tefe.Init(&context)); 94 } 95 96 TEST_F(EmbeddingFeatureExtractorTest, BrokenSpec) { 97 TaskContext context; 98 const std::string spec = 99 "continuous-bag-of-ngrams(id_dim=5000;" 100 "continuous-bag-of-ngrams(id_dim=7000,size=4)"; 101 context.SetParameter("test_features", spec); 102 context.SetParameter("test_embedding_names", "trigram;quadgram"); 103 context.SetParameter("test_embedding_dims", "16;16"); 104 TestEmbeddingFeatureExtractor tefe; 105 ASSERT_FALSE(tefe.Init(&context)); 106 } 107 108 TEST_F(EmbeddingFeatureExtractorTest, MissingFeature) { 109 TaskContext context; 110 const std::string spec = 111 "continuous-bag-of-ngrams(id_dim=5000,size=3);" 112 "no-such-feature"; 113 context.SetParameter("test_features", spec); 114 context.SetParameter("test_embedding_names", "trigram;foo"); 115 context.SetParameter("test_embedding_dims", "16;16"); 116 TestEmbeddingFeatureExtractor tefe; 117 ASSERT_FALSE(tefe.Init(&context)); 118 } 119 120 TEST_F(EmbeddingFeatureExtractorTest, MultipleFeatures) { 121 TaskContext context; 122 const std::string spec = 123 "continuous-bag-of-ngrams(id_dim=1000,size=3);" 124 "continuous-bag-of-relevant-scripts"; 125 context.SetParameter("test_features", spec); 126 context.SetParameter("test_embedding_names", "trigram;script"); 127 context.SetParameter("test_embedding_dims", "8;16"); 128 TestEmbeddingFeatureExtractor tefe; 129 ASSERT_TRUE(tefe.Init(&context)); 130 EXPECT_EQ(tefe.NumEmbeddings(), 2); 131 EXPECT_EQ(tefe.EmbeddingSize(0), 1000); 132 EXPECT_EQ(tefe.EmbeddingDims(0), 8); 133 134 // continuous-bag-of-relevant-scripts has its own hard-wired vocabulary size. 135 // We don't want this test to depend on that value; we just check it's bigger 136 // than 0. 137 EXPECT_GT(tefe.EmbeddingSize(1), 0); 138 EXPECT_EQ(tefe.EmbeddingDims(1), 16); 139 } 140 141 } // namespace nlp_core 142 } // namespace libtextclassifier 143