Home | History | Annotate | Download | only in libtextclassifier
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "token-feature-extractor.h"
     18 
     19 #include "gmock/gmock.h"
     20 #include "gtest/gtest.h"
     21 
     22 namespace libtextclassifier2 {
     23 namespace {
     24 
     25 class TestingTokenFeatureExtractor : public TokenFeatureExtractor {
     26  public:
     27   using TokenFeatureExtractor::HashToken;
     28   using TokenFeatureExtractor::TokenFeatureExtractor;
     29 };
     30 
     31 TEST(TokenFeatureExtractorTest, ExtractAscii) {
     32   TokenFeatureExtractorOptions options;
     33   options.num_buckets = 1000;
     34   options.chargram_orders = std::vector<int>{1, 2, 3};
     35   options.extract_case_feature = true;
     36   options.unicode_aware_features = false;
     37   options.extract_selection_mask_feature = true;
     38   CREATE_UNILIB_FOR_TESTING
     39   TestingTokenFeatureExtractor extractor(options, unilib);
     40 
     41   std::vector<int> sparse_features;
     42   std::vector<float> dense_features;
     43 
     44   extractor.Extract(Token{"Hello", 0, 5}, true, &sparse_features,
     45                     &dense_features);
     46 
     47   EXPECT_THAT(sparse_features,
     48               testing::ElementsAreArray({
     49                   // clang-format off
     50                   extractor.HashToken("H"),
     51                   extractor.HashToken("e"),
     52                   extractor.HashToken("l"),
     53                   extractor.HashToken("l"),
     54                   extractor.HashToken("o"),
     55                   extractor.HashToken("^H"),
     56                   extractor.HashToken("He"),
     57                   extractor.HashToken("el"),
     58                   extractor.HashToken("ll"),
     59                   extractor.HashToken("lo"),
     60                   extractor.HashToken("o$"),
     61                   extractor.HashToken("^He"),
     62                   extractor.HashToken("Hel"),
     63                   extractor.HashToken("ell"),
     64                   extractor.HashToken("llo"),
     65                   extractor.HashToken("lo$")
     66                   // clang-format on
     67               }));
     68   EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
     69 
     70   sparse_features.clear();
     71   dense_features.clear();
     72   extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
     73                     &dense_features);
     74 
     75   EXPECT_THAT(sparse_features,
     76               testing::ElementsAreArray({
     77                   // clang-format off
     78                   extractor.HashToken("w"),
     79                   extractor.HashToken("o"),
     80                   extractor.HashToken("r"),
     81                   extractor.HashToken("l"),
     82                   extractor.HashToken("d"),
     83                   extractor.HashToken("!"),
     84                   extractor.HashToken("^w"),
     85                   extractor.HashToken("wo"),
     86                   extractor.HashToken("or"),
     87                   extractor.HashToken("rl"),
     88                   extractor.HashToken("ld"),
     89                   extractor.HashToken("d!"),
     90                   extractor.HashToken("!$"),
     91                   extractor.HashToken("^wo"),
     92                   extractor.HashToken("wor"),
     93                   extractor.HashToken("orl"),
     94                   extractor.HashToken("rld"),
     95                   extractor.HashToken("ld!"),
     96                   extractor.HashToken("d!$"),
     97                   // clang-format on
     98               }));
     99   EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
    100 }
    101 
    102 TEST(TokenFeatureExtractorTest, ExtractAsciiNoChargrams) {
    103   TokenFeatureExtractorOptions options;
    104   options.num_buckets = 1000;
    105   options.chargram_orders = std::vector<int>{};
    106   options.extract_case_feature = true;
    107   options.unicode_aware_features = false;
    108   options.extract_selection_mask_feature = true;
    109   CREATE_UNILIB_FOR_TESTING
    110   TestingTokenFeatureExtractor extractor(options, unilib);
    111 
    112   std::vector<int> sparse_features;
    113   std::vector<float> dense_features;
    114 
    115   extractor.Extract(Token{"Hello", 0, 5}, true, &sparse_features,
    116                     &dense_features);
    117 
    118   EXPECT_THAT(sparse_features,
    119               testing::ElementsAreArray({extractor.HashToken("^Hello$")}));
    120   EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
    121 
    122   sparse_features.clear();
    123   dense_features.clear();
    124   extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
    125                     &dense_features);
    126 
    127   EXPECT_THAT(sparse_features,
    128               testing::ElementsAreArray({extractor.HashToken("^world!$")}));
    129   EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
    130 }
    131 
    132 TEST(TokenFeatureExtractorTest, ExtractUnicode) {
    133   TokenFeatureExtractorOptions options;
    134   options.num_buckets = 1000;
    135   options.chargram_orders = std::vector<int>{1, 2, 3};
    136   options.extract_case_feature = true;
    137   options.unicode_aware_features = true;
    138   options.extract_selection_mask_feature = true;
    139   CREATE_UNILIB_FOR_TESTING
    140   TestingTokenFeatureExtractor extractor(options, unilib);
    141 
    142   std::vector<int> sparse_features;
    143   std::vector<float> dense_features;
    144 
    145   extractor.Extract(Token{"Hll", 0, 5}, true, &sparse_features,
    146                     &dense_features);
    147 
    148   EXPECT_THAT(sparse_features,
    149               testing::ElementsAreArray({
    150                   // clang-format off
    151                   extractor.HashToken("H"),
    152                   extractor.HashToken(""),
    153                   extractor.HashToken("l"),
    154                   extractor.HashToken("l"),
    155                   extractor.HashToken(""),
    156                   extractor.HashToken("^H"),
    157                   extractor.HashToken("H"),
    158                   extractor.HashToken("l"),
    159                   extractor.HashToken("ll"),
    160                   extractor.HashToken("l"),
    161                   extractor.HashToken("$"),
    162                   extractor.HashToken("^H"),
    163                   extractor.HashToken("Hl"),
    164                   extractor.HashToken("ll"),
    165                   extractor.HashToken("ll"),
    166                   extractor.HashToken("l$")
    167                   // clang-format on
    168               }));
    169   EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
    170 
    171   sparse_features.clear();
    172   dense_features.clear();
    173   extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
    174                     &dense_features);
    175 
    176   EXPECT_THAT(sparse_features,
    177               testing::ElementsAreArray({
    178                   // clang-format off
    179                   extractor.HashToken("w"),
    180                   extractor.HashToken("o"),
    181                   extractor.HashToken("r"),
    182                   extractor.HashToken("l"),
    183                   extractor.HashToken("d"),
    184                   extractor.HashToken("!"),
    185                   extractor.HashToken("^w"),
    186                   extractor.HashToken("wo"),
    187                   extractor.HashToken("or"),
    188                   extractor.HashToken("rl"),
    189                   extractor.HashToken("ld"),
    190                   extractor.HashToken("d!"),
    191                   extractor.HashToken("!$"),
    192                   extractor.HashToken("^wo"),
    193                   extractor.HashToken("wor"),
    194                   extractor.HashToken("orl"),
    195                   extractor.HashToken("rld"),
    196                   extractor.HashToken("ld!"),
    197                   extractor.HashToken("d!$"),
    198                   // clang-format on
    199               }));
    200   EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
    201 }
    202 
    203 TEST(TokenFeatureExtractorTest, ExtractUnicodeNoChargrams) {
    204   TokenFeatureExtractorOptions options;
    205   options.num_buckets = 1000;
    206   options.chargram_orders = std::vector<int>{};
    207   options.extract_case_feature = true;
    208   options.unicode_aware_features = true;
    209   options.extract_selection_mask_feature = true;
    210   CREATE_UNILIB_FOR_TESTING
    211   TestingTokenFeatureExtractor extractor(options, unilib);
    212 
    213   std::vector<int> sparse_features;
    214   std::vector<float> dense_features;
    215 
    216   extractor.Extract(Token{"Hll", 0, 5}, true, &sparse_features,
    217                     &dense_features);
    218 
    219   EXPECT_THAT(sparse_features,
    220               testing::ElementsAreArray({extractor.HashToken("^Hll$")}));
    221   EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
    222 
    223   sparse_features.clear();
    224   dense_features.clear();
    225   extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
    226                     &dense_features);
    227 
    228   EXPECT_THAT(sparse_features, testing::ElementsAreArray({
    229                                    extractor.HashToken("^world!$"),
    230                                }));
    231   EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
    232 }
    233 
    234 #ifdef LIBTEXTCLASSIFIER_TEST_ICU
    235 TEST(TokenFeatureExtractorTest, ICUCaseFeature) {
    236   TokenFeatureExtractorOptions options;
    237   options.num_buckets = 1000;
    238   options.chargram_orders = std::vector<int>{1, 2};
    239   options.extract_case_feature = true;
    240   options.unicode_aware_features = true;
    241   options.extract_selection_mask_feature = false;
    242   CREATE_UNILIB_FOR_TESTING
    243   TestingTokenFeatureExtractor extractor(options, unilib);
    244 
    245   std::vector<int> sparse_features;
    246   std::vector<float> dense_features;
    247   extractor.Extract(Token{"Hll", 0, 5}, true, &sparse_features,
    248                     &dense_features);
    249   EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0}));
    250 
    251   sparse_features.clear();
    252   dense_features.clear();
    253   extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
    254                     &dense_features);
    255   EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0}));
    256 
    257   sparse_features.clear();
    258   dense_features.clear();
    259   extractor.Extract(Token{"", 23, 29}, false, &sparse_features,
    260                     &dense_features);
    261   EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0}));
    262 
    263   sparse_features.clear();
    264   dense_features.clear();
    265   extractor.Extract(Token{"", 23, 29}, false, &sparse_features,
    266                     &dense_features);
    267   EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0}));
    268 }
    269 #endif
    270 
    271 TEST(TokenFeatureExtractorTest, DigitRemapping) {
    272   TokenFeatureExtractorOptions options;
    273   options.num_buckets = 1000;
    274   options.chargram_orders = std::vector<int>{1, 2};
    275   options.remap_digits = true;
    276   options.unicode_aware_features = false;
    277   CREATE_UNILIB_FOR_TESTING
    278   TestingTokenFeatureExtractor extractor(options, unilib);
    279 
    280   std::vector<int> sparse_features;
    281   std::vector<float> dense_features;
    282   extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features,
    283                     &dense_features);
    284 
    285   std::vector<int> sparse_features2;
    286   extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2,
    287                     &dense_features);
    288   EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
    289 
    290   extractor.Extract(Token{"10:32am", 0, 6}, true, &sparse_features2,
    291                     &dense_features);
    292   EXPECT_THAT(sparse_features,
    293               testing::Not(testing::ElementsAreArray(sparse_features2)));
    294 }
    295 
    296 TEST(TokenFeatureExtractorTest, DigitRemappingUnicode) {
    297   TokenFeatureExtractorOptions options;
    298   options.num_buckets = 1000;
    299   options.chargram_orders = std::vector<int>{1, 2};
    300   options.remap_digits = true;
    301   options.unicode_aware_features = true;
    302   CREATE_UNILIB_FOR_TESTING
    303   TestingTokenFeatureExtractor extractor(options, unilib);
    304 
    305   std::vector<int> sparse_features;
    306   std::vector<float> dense_features;
    307   extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features,
    308                     &dense_features);
    309 
    310   std::vector<int> sparse_features2;
    311   extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2,
    312                     &dense_features);
    313   EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
    314 
    315   extractor.Extract(Token{"10:32am", 0, 6}, true, &sparse_features2,
    316                     &dense_features);
    317   EXPECT_THAT(sparse_features,
    318               testing::Not(testing::ElementsAreArray(sparse_features2)));
    319 }
    320 
    321 TEST(TokenFeatureExtractorTest, LowercaseAscii) {
    322   TokenFeatureExtractorOptions options;
    323   options.num_buckets = 1000;
    324   options.chargram_orders = std::vector<int>{1, 2};
    325   options.lowercase_tokens = true;
    326   options.unicode_aware_features = false;
    327   CREATE_UNILIB_FOR_TESTING
    328   TestingTokenFeatureExtractor extractor(options, unilib);
    329 
    330   std::vector<int> sparse_features;
    331   std::vector<float> dense_features;
    332   extractor.Extract(Token{"AABB", 0, 6}, true, &sparse_features,
    333                     &dense_features);
    334 
    335   std::vector<int> sparse_features2;
    336   extractor.Extract(Token{"aaBB", 0, 6}, true, &sparse_features2,
    337                     &dense_features);
    338   EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
    339 
    340   extractor.Extract(Token{"aAbB", 0, 6}, true, &sparse_features2,
    341                     &dense_features);
    342   EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
    343 }
    344 
    345 #ifdef LIBTEXTCLASSIFIER_TEST_ICU
    346 TEST(TokenFeatureExtractorTest, LowercaseUnicode) {
    347   TokenFeatureExtractorOptions options;
    348   options.num_buckets = 1000;
    349   options.chargram_orders = std::vector<int>{1, 2};
    350   options.lowercase_tokens = true;
    351   options.unicode_aware_features = true;
    352   CREATE_UNILIB_FOR_TESTING
    353   TestingTokenFeatureExtractor extractor(options, unilib);
    354 
    355   std::vector<int> sparse_features;
    356   std::vector<float> dense_features;
    357   extractor.Extract(Token{"", 0, 6}, true, &sparse_features, &dense_features);
    358 
    359   std::vector<int> sparse_features2;
    360   extractor.Extract(Token{"", 0, 6}, true, &sparse_features2,
    361                     &dense_features);
    362   EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
    363 }
    364 #endif
    365 
    366 #ifdef LIBTEXTCLASSIFIER_TEST_ICU
    367 TEST(TokenFeatureExtractorTest, RegexFeatures) {
    368   TokenFeatureExtractorOptions options;
    369   options.num_buckets = 1000;
    370   options.chargram_orders = std::vector<int>{1, 2};
    371   options.remap_digits = false;
    372   options.unicode_aware_features = false;
    373   options.regexp_features.push_back("^[a-z]+$");  // all lower case.
    374   options.regexp_features.push_back("^[0-9]+$");  // all digits.
    375   CREATE_UNILIB_FOR_TESTING
    376   TestingTokenFeatureExtractor extractor(options, unilib);
    377 
    378   std::vector<int> sparse_features;
    379   std::vector<float> dense_features;
    380   extractor.Extract(Token{"abCde", 0, 6}, true, &sparse_features,
    381                     &dense_features);
    382   EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
    383 
    384   dense_features.clear();
    385   extractor.Extract(Token{"abcde", 0, 6}, true, &sparse_features,
    386                     &dense_features);
    387   EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, -1.0}));
    388 
    389   dense_features.clear();
    390   extractor.Extract(Token{"12c45", 0, 6}, true, &sparse_features,
    391                     &dense_features);
    392   EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
    393 
    394   dense_features.clear();
    395   extractor.Extract(Token{"12345", 0, 6}, true, &sparse_features,
    396                     &dense_features);
    397   EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 1.0}));
    398 }
    399 #endif
    400 
    401 TEST(TokenFeatureExtractorTest, ExtractTooLongWord) {
    402   TokenFeatureExtractorOptions options;
    403   options.num_buckets = 1000;
    404   options.chargram_orders = std::vector<int>{22};
    405   options.extract_case_feature = true;
    406   options.unicode_aware_features = true;
    407   options.extract_selection_mask_feature = true;
    408   CREATE_UNILIB_FOR_TESTING
    409   TestingTokenFeatureExtractor extractor(options, unilib);
    410 
    411   // Test that this runs. ASAN should catch problems.
    412   std::vector<int> sparse_features;
    413   std::vector<float> dense_features;
    414   extractor.Extract(Token{"abcdefghijklmnopqstuvwxyz", 0, 0}, true,
    415                     &sparse_features, &dense_features);
    416 
    417   EXPECT_THAT(sparse_features,
    418               testing::ElementsAreArray({
    419                   // clang-format off
    420                   extractor.HashToken("^abcdefghij\1qstuvwxyz"),
    421                   extractor.HashToken("abcdefghij\1qstuvwxyz$"),
    422                   // clang-format on
    423               }));
    424 }
    425 
    426 TEST(TokenFeatureExtractorTest, ExtractAsciiUnicodeMatches) {
    427   TokenFeatureExtractorOptions options;
    428   options.num_buckets = 1000;
    429   options.chargram_orders = std::vector<int>{1, 2, 3, 4, 5};
    430   options.extract_case_feature = true;
    431   options.unicode_aware_features = true;
    432   options.extract_selection_mask_feature = true;
    433 
    434   CREATE_UNILIB_FOR_TESTING
    435   TestingTokenFeatureExtractor extractor_unicode(options, unilib);
    436 
    437   options.unicode_aware_features = false;
    438   TestingTokenFeatureExtractor extractor_ascii(options, unilib);
    439 
    440   for (const std::string& input :
    441        {"https://www.abcdefgh.com/in/xxxkkkvayio",
    442         "https://www.fjsidofj.om/xx/abadfy/xxxx/?xfjiis=ffffiijiihil",
    443         "asdfhasdofjiasdofj#%()*%#*(aisdojfaosdifjiaofjdsiofjdi_fdis3w", "abcd",
    444         "x", "Hello", "Hey,", "Hi", ""}) {
    445     std::vector<int> sparse_features_unicode;
    446     std::vector<float> dense_features_unicode;
    447     extractor_unicode.Extract(Token{input, 0, 0}, true,
    448                               &sparse_features_unicode,
    449                               &dense_features_unicode);
    450 
    451     std::vector<int> sparse_features_ascii;
    452     std::vector<float> dense_features_ascii;
    453     extractor_ascii.Extract(Token{input, 0, 0}, true, &sparse_features_ascii,
    454                             &dense_features_ascii);
    455 
    456     EXPECT_THAT(sparse_features_unicode, sparse_features_ascii) << input;
    457     EXPECT_THAT(dense_features_unicode, dense_features_ascii) << input;
    458   }
    459 }
    460 
    461 TEST(TokenFeatureExtractorTest, ExtractForPadToken) {
    462   TokenFeatureExtractorOptions options;
    463   options.num_buckets = 1000;
    464   options.chargram_orders = std::vector<int>{1, 2};
    465   options.extract_case_feature = true;
    466   options.unicode_aware_features = false;
    467   options.extract_selection_mask_feature = true;
    468 
    469   CREATE_UNILIB_FOR_TESTING
    470   TestingTokenFeatureExtractor extractor(options, unilib);
    471 
    472   std::vector<int> sparse_features;
    473   std::vector<float> dense_features;
    474 
    475   extractor.Extract(Token(), false, &sparse_features, &dense_features);
    476 
    477   EXPECT_THAT(sparse_features,
    478               testing::ElementsAreArray({extractor.HashToken("<PAD>")}));
    479   EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
    480 }
    481 
    482 TEST(TokenFeatureExtractorTest, ExtractFiltered) {
    483   TokenFeatureExtractorOptions options;
    484   options.num_buckets = 1000;
    485   options.chargram_orders = std::vector<int>{1, 2, 3};
    486   options.extract_case_feature = true;
    487   options.unicode_aware_features = false;
    488   options.extract_selection_mask_feature = true;
    489   options.allowed_chargrams.insert("^H");
    490   options.allowed_chargrams.insert("ll");
    491   options.allowed_chargrams.insert("llo");
    492   options.allowed_chargrams.insert("w");
    493   options.allowed_chargrams.insert("!");
    494   options.allowed_chargrams.insert("\xc4");  // UTF8 control character.
    495 
    496   CREATE_UNILIB_FOR_TESTING
    497   TestingTokenFeatureExtractor extractor(options, unilib);
    498 
    499   std::vector<int> sparse_features;
    500   std::vector<float> dense_features;
    501 
    502   extractor.Extract(Token{"Hllo", 0, 5}, true, &sparse_features,
    503                     &dense_features);
    504 
    505   EXPECT_THAT(sparse_features,
    506               testing::ElementsAreArray({
    507                   // clang-format off
    508                   0,
    509                   extractor.HashToken("\xc4"),
    510                   0,
    511                   0,
    512                   0,
    513                   0,
    514                   extractor.HashToken("^H"),
    515                   0,
    516                   0,
    517                   0,
    518                   extractor.HashToken("ll"),
    519                   0,
    520                   0,
    521                   0,
    522                   0,
    523                   0,
    524                   0,
    525                   extractor.HashToken("llo"),
    526                   0
    527                   // clang-format on
    528               }));
    529   EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
    530 
    531   sparse_features.clear();
    532   dense_features.clear();
    533   extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
    534                     &dense_features);
    535 
    536   EXPECT_THAT(sparse_features, testing::ElementsAreArray({
    537                                    // clang-format off
    538                   extractor.HashToken("w"),
    539                   0,
    540                   0,
    541                   0,
    542                   0,
    543                   extractor.HashToken("!"),
    544                   0,
    545                   0,
    546                   0,
    547                   0,
    548                   0,
    549                   0,
    550                   0,
    551                   0,
    552                   0,
    553                   0,
    554                   0,
    555                   0,
    556                   0,
    557                                    // clang-format on
    558                                }));
    559   EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
    560   EXPECT_EQ(extractor.HashToken("<PAD>"), 1);
    561 }
    562 
    563 }  // namespace
    564 }  // namespace libtextclassifier2
    565