Home | History | Annotate | Download | only in tests
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "common/embedding-feature-extractor.h"
     18 
     19 #include "lang_id/language-identifier-features.h"
     20 #include "lang_id/light-sentence-features.h"
     21 #include "lang_id/light-sentence.h"
     22 #include "lang_id/relevant-script-feature.h"
     23 #include "gtest/gtest.h"
     24 
     25 namespace libtextclassifier {
     26 namespace nlp_core {
     27 
     28 class EmbeddingFeatureExtractorTest : public ::testing::Test {
     29  public:
     30   void SetUp() override {
     31     // Make sure all relevant features are registered:
     32     lang_id::ContinuousBagOfNgramsFunction::RegisterClass();
     33     lang_id::RelevantScriptFeature::RegisterClass();
     34   }
     35 };
     36 
     37 // Specialization of EmbeddingFeatureExtractor that extracts from LightSentence.
     38 class TestEmbeddingFeatureExtractor
     39     : public EmbeddingFeatureExtractor<lang_id::LightSentenceExtractor,
     40                                        lang_id::LightSentence> {
     41  public:
     42   const std::string ArgPrefix() const override { return "test"; }
     43 };
     44 
     45 TEST_F(EmbeddingFeatureExtractorTest, NoEmbeddingSpaces) {
     46   TaskContext context;
     47   context.SetParameter("test_features", "");
     48   context.SetParameter("test_embedding_names", "");
     49   context.SetParameter("test_embedding_dims", "");
     50   TestEmbeddingFeatureExtractor tefe;
     51   ASSERT_TRUE(tefe.Init(&context));
     52   EXPECT_EQ(tefe.NumEmbeddings(), 0);
     53 }
     54 
     55 TEST_F(EmbeddingFeatureExtractorTest, GoodSpec) {
     56   TaskContext context;
     57   const std::string spec =
     58       "continuous-bag-of-ngrams(id_dim=5000,size=3);"
     59       "continuous-bag-of-ngrams(id_dim=7000,size=4)";
     60   context.SetParameter("test_features", spec);
     61   context.SetParameter("test_embedding_names", "trigram;quadgram");
     62   context.SetParameter("test_embedding_dims", "16;24");
     63   TestEmbeddingFeatureExtractor tefe;
     64   ASSERT_TRUE(tefe.Init(&context));
     65   EXPECT_EQ(tefe.NumEmbeddings(), 2);
     66   EXPECT_EQ(tefe.EmbeddingSize(0), 5000);
     67   EXPECT_EQ(tefe.EmbeddingDims(0), 16);
     68   EXPECT_EQ(tefe.EmbeddingSize(1), 7000);
     69   EXPECT_EQ(tefe.EmbeddingDims(1), 24);
     70 }
     71 
     72 TEST_F(EmbeddingFeatureExtractorTest, MissmatchFmlVsNames) {
     73   TaskContext context;
     74   const std::string spec =
     75       "continuous-bag-of-ngrams(id_dim=5000,size=3);"
     76       "continuous-bag-of-ngrams(id_dim=7000,size=4)";
     77   context.SetParameter("test_features", spec);
     78   context.SetParameter("test_embedding_names", "trigram");
     79   context.SetParameter("test_embedding_dims", "16;16");
     80   TestEmbeddingFeatureExtractor tefe;
     81   ASSERT_FALSE(tefe.Init(&context));
     82 }
     83 
     84 TEST_F(EmbeddingFeatureExtractorTest, MissmatchFmlVsDims) {
     85   TaskContext context;
     86   const std::string spec =
     87       "continuous-bag-of-ngrams(id_dim=5000,size=3);"
     88       "continuous-bag-of-ngrams(id_dim=7000,size=4)";
     89   context.SetParameter("test_features", spec);
     90   context.SetParameter("test_embedding_names", "trigram;quadgram");
     91   context.SetParameter("test_embedding_dims", "16;16;32");
     92   TestEmbeddingFeatureExtractor tefe;
     93   ASSERT_FALSE(tefe.Init(&context));
     94 }
     95 
     96 TEST_F(EmbeddingFeatureExtractorTest, BrokenSpec) {
     97   TaskContext context;
     98   const std::string spec =
     99       "continuous-bag-of-ngrams(id_dim=5000;"
    100       "continuous-bag-of-ngrams(id_dim=7000,size=4)";
    101   context.SetParameter("test_features", spec);
    102   context.SetParameter("test_embedding_names", "trigram;quadgram");
    103   context.SetParameter("test_embedding_dims", "16;16");
    104   TestEmbeddingFeatureExtractor tefe;
    105   ASSERT_FALSE(tefe.Init(&context));
    106 }
    107 
    108 TEST_F(EmbeddingFeatureExtractorTest, MissingFeature) {
    109   TaskContext context;
    110   const std::string spec =
    111       "continuous-bag-of-ngrams(id_dim=5000,size=3);"
    112       "no-such-feature";
    113   context.SetParameter("test_features", spec);
    114   context.SetParameter("test_embedding_names", "trigram;foo");
    115   context.SetParameter("test_embedding_dims", "16;16");
    116   TestEmbeddingFeatureExtractor tefe;
    117   ASSERT_FALSE(tefe.Init(&context));
    118 }
    119 
    120 TEST_F(EmbeddingFeatureExtractorTest, MultipleFeatures) {
    121   TaskContext context;
    122   const std::string spec =
    123       "continuous-bag-of-ngrams(id_dim=1000,size=3);"
    124       "continuous-bag-of-relevant-scripts";
    125   context.SetParameter("test_features", spec);
    126   context.SetParameter("test_embedding_names", "trigram;script");
    127   context.SetParameter("test_embedding_dims", "8;16");
    128   TestEmbeddingFeatureExtractor tefe;
    129   ASSERT_TRUE(tefe.Init(&context));
    130   EXPECT_EQ(tefe.NumEmbeddings(), 2);
    131   EXPECT_EQ(tefe.EmbeddingSize(0), 1000);
    132   EXPECT_EQ(tefe.EmbeddingDims(0), 8);
    133 
    134   // continuous-bag-of-relevant-scripts has its own hard-wired vocabulary size.
    135   // We don't want this test to depend on that value; we just check it's bigger
    136   // than 0.
    137   EXPECT_GT(tefe.EmbeddingSize(1), 0);
    138   EXPECT_EQ(tefe.EmbeddingDims(1), 16);
    139 }
    140 
    141 }  // namespace nlp_core
    142 }  // namespace libtextclassifier
    143