Home | History | Annotate | Download | only in tests
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "smartselect/token-feature-extractor.h"
     18 
     19 #include "gmock/gmock.h"
     20 #include "gtest/gtest.h"
     21 
     22 namespace libtextclassifier {
     23 namespace {
     24 
     25 class TestingTokenFeatureExtractor : public TokenFeatureExtractor {
     26  public:
     27   using TokenFeatureExtractor::TokenFeatureExtractor;
     28   using TokenFeatureExtractor::HashToken;
     29 };
     30 
     31 TEST(TokenFeatureExtractorTest, ExtractAscii) {
     32   TokenFeatureExtractorOptions options;
     33   options.num_buckets = 1000;
     34   options.chargram_orders = std::vector<int>{1, 2, 3};
     35   options.extract_case_feature = true;
     36   options.unicode_aware_features = false;
     37   options.extract_selection_mask_feature = true;
     38   TestingTokenFeatureExtractor extractor(options);
     39 
     40   std::vector<int> sparse_features;
     41   std::vector<float> dense_features;
     42 
     43   extractor.Extract(Token{"Hello", 0, 5}, true, &sparse_features,
     44                     &dense_features);
     45 
     46   EXPECT_THAT(sparse_features,
     47               testing::ElementsAreArray({
     48                   // clang-format off
     49                   extractor.HashToken("H"),
     50                   extractor.HashToken("e"),
     51                   extractor.HashToken("l"),
     52                   extractor.HashToken("l"),
     53                   extractor.HashToken("o"),
     54                   extractor.HashToken("^H"),
     55                   extractor.HashToken("He"),
     56                   extractor.HashToken("el"),
     57                   extractor.HashToken("ll"),
     58                   extractor.HashToken("lo"),
     59                   extractor.HashToken("o$"),
     60                   extractor.HashToken("^He"),
     61                   extractor.HashToken("Hel"),
     62                   extractor.HashToken("ell"),
     63                   extractor.HashToken("llo"),
     64                   extractor.HashToken("lo$")
     65                   // clang-format on
     66               }));
     67   EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
     68 
     69   sparse_features.clear();
     70   dense_features.clear();
     71   extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
     72                     &dense_features);
     73 
     74   EXPECT_THAT(sparse_features,
     75               testing::ElementsAreArray({
     76                   // clang-format off
     77                   extractor.HashToken("w"),
     78                   extractor.HashToken("o"),
     79                   extractor.HashToken("r"),
     80                   extractor.HashToken("l"),
     81                   extractor.HashToken("d"),
     82                   extractor.HashToken("!"),
     83                   extractor.HashToken("^w"),
     84                   extractor.HashToken("wo"),
     85                   extractor.HashToken("or"),
     86                   extractor.HashToken("rl"),
     87                   extractor.HashToken("ld"),
     88                   extractor.HashToken("d!"),
     89                   extractor.HashToken("!$"),
     90                   extractor.HashToken("^wo"),
     91                   extractor.HashToken("wor"),
     92                   extractor.HashToken("orl"),
     93                   extractor.HashToken("rld"),
     94                   extractor.HashToken("ld!"),
     95                   extractor.HashToken("d!$"),
     96                   // clang-format on
     97               }));
     98   EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
     99 }
    100 
    101 TEST(TokenFeatureExtractorTest, ExtractUnicode) {
    102   TokenFeatureExtractorOptions options;
    103   options.num_buckets = 1000;
    104   options.chargram_orders = std::vector<int>{1, 2, 3};
    105   options.extract_case_feature = true;
    106   options.unicode_aware_features = true;
    107   options.extract_selection_mask_feature = true;
    108   TestingTokenFeatureExtractor extractor(options);
    109 
    110   std::vector<int> sparse_features;
    111   std::vector<float> dense_features;
    112 
    113   extractor.Extract(Token{"Hll", 0, 5}, true, &sparse_features,
    114                     &dense_features);
    115 
    116   EXPECT_THAT(sparse_features,
    117               testing::ElementsAreArray({
    118                   // clang-format off
    119                   extractor.HashToken("H"),
    120                   extractor.HashToken(""),
    121                   extractor.HashToken("l"),
    122                   extractor.HashToken("l"),
    123                   extractor.HashToken(""),
    124                   extractor.HashToken("^H"),
    125                   extractor.HashToken("H"),
    126                   extractor.HashToken("l"),
    127                   extractor.HashToken("ll"),
    128                   extractor.HashToken("l"),
    129                   extractor.HashToken("$"),
    130                   extractor.HashToken("^H"),
    131                   extractor.HashToken("Hl"),
    132                   extractor.HashToken("ll"),
    133                   extractor.HashToken("ll"),
    134                   extractor.HashToken("l$")
    135                   // clang-format on
    136               }));
    137   EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0}));
    138 
    139   sparse_features.clear();
    140   dense_features.clear();
    141   extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
    142                     &dense_features);
    143 
    144   EXPECT_THAT(sparse_features,
    145               testing::ElementsAreArray({
    146                   // clang-format off
    147                   extractor.HashToken("w"),
    148                   extractor.HashToken("o"),
    149                   extractor.HashToken("r"),
    150                   extractor.HashToken("l"),
    151                   extractor.HashToken("d"),
    152                   extractor.HashToken("!"),
    153                   extractor.HashToken("^w"),
    154                   extractor.HashToken("wo"),
    155                   extractor.HashToken("or"),
    156                   extractor.HashToken("rl"),
    157                   extractor.HashToken("ld"),
    158                   extractor.HashToken("d!"),
    159                   extractor.HashToken("!$"),
    160                   extractor.HashToken("^wo"),
    161                   extractor.HashToken("wor"),
    162                   extractor.HashToken("orl"),
    163                   extractor.HashToken("rld"),
    164                   extractor.HashToken("ld!"),
    165                   extractor.HashToken("d!$"),
    166                   // clang-format on
    167               }));
    168   EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
    169 }
    170 
    171 TEST(TokenFeatureExtractorTest, ICUCaseFeature) {
    172   TokenFeatureExtractorOptions options;
    173   options.num_buckets = 1000;
    174   options.chargram_orders = std::vector<int>{1, 2};
    175   options.extract_case_feature = true;
    176   options.unicode_aware_features = true;
    177   options.extract_selection_mask_feature = false;
    178   TokenFeatureExtractor extractor(options);
    179 
    180   std::vector<int> sparse_features;
    181   std::vector<float> dense_features;
    182   extractor.Extract(Token{"Hll", 0, 5}, true, &sparse_features,
    183                     &dense_features);
    184   EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0}));
    185 
    186   sparse_features.clear();
    187   dense_features.clear();
    188   extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
    189                     &dense_features);
    190   EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0}));
    191 
    192   sparse_features.clear();
    193   dense_features.clear();
    194   extractor.Extract(Token{"", 23, 29}, false, &sparse_features,
    195                     &dense_features);
    196   EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0}));
    197 
    198   sparse_features.clear();
    199   dense_features.clear();
    200   extractor.Extract(Token{"", 23, 29}, false, &sparse_features,
    201                     &dense_features);
    202   EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0}));
    203 }
    204 
    205 TEST(TokenFeatureExtractorTest, DigitRemapping) {
    206   TokenFeatureExtractorOptions options;
    207   options.num_buckets = 1000;
    208   options.chargram_orders = std::vector<int>{1, 2};
    209   options.remap_digits = true;
    210   options.unicode_aware_features = false;
    211   TokenFeatureExtractor extractor(options);
    212 
    213   std::vector<int> sparse_features;
    214   std::vector<float> dense_features;
    215   extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features,
    216                     &dense_features);
    217 
    218   std::vector<int> sparse_features2;
    219   extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2,
    220                     &dense_features);
    221   EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
    222 
    223   extractor.Extract(Token{"10:32am", 0, 6}, true, &sparse_features2,
    224                     &dense_features);
    225   EXPECT_THAT(sparse_features,
    226               testing::Not(testing::ElementsAreArray(sparse_features2)));
    227 }
    228 
    229 TEST(TokenFeatureExtractorTest, DigitRemappingUnicode) {
    230   TokenFeatureExtractorOptions options;
    231   options.num_buckets = 1000;
    232   options.chargram_orders = std::vector<int>{1, 2};
    233   options.remap_digits = true;
    234   options.unicode_aware_features = true;
    235   TokenFeatureExtractor extractor(options);
    236 
    237   std::vector<int> sparse_features;
    238   std::vector<float> dense_features;
    239   extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features,
    240                     &dense_features);
    241 
    242   std::vector<int> sparse_features2;
    243   extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2,
    244                     &dense_features);
    245   EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
    246 
    247   extractor.Extract(Token{"10:32am", 0, 6}, true, &sparse_features2,
    248                     &dense_features);
    249   EXPECT_THAT(sparse_features,
    250               testing::Not(testing::ElementsAreArray(sparse_features2)));
    251 }
    252 
    253 TEST(TokenFeatureExtractorTest, LowercaseAscii) {
    254   TokenFeatureExtractorOptions options;
    255   options.num_buckets = 1000;
    256   options.chargram_orders = std::vector<int>{1, 2};
    257   options.lowercase_tokens = true;
    258   options.unicode_aware_features = false;
    259   TokenFeatureExtractor extractor(options);
    260 
    261   std::vector<int> sparse_features;
    262   std::vector<float> dense_features;
    263   extractor.Extract(Token{"AABB", 0, 6}, true, &sparse_features,
    264                     &dense_features);
    265 
    266   std::vector<int> sparse_features2;
    267   extractor.Extract(Token{"aaBB", 0, 6}, true, &sparse_features2,
    268                     &dense_features);
    269   EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
    270 
    271   extractor.Extract(Token{"aAbB", 0, 6}, true, &sparse_features2,
    272                     &dense_features);
    273   EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
    274 }
    275 
    276 TEST(TokenFeatureExtractorTest, LowercaseUnicode) {
    277   TokenFeatureExtractorOptions options;
    278   options.num_buckets = 1000;
    279   options.chargram_orders = std::vector<int>{1, 2};
    280   options.lowercase_tokens = true;
    281   options.unicode_aware_features = true;
    282   TokenFeatureExtractor extractor(options);
    283 
    284   std::vector<int> sparse_features;
    285   std::vector<float> dense_features;
    286   extractor.Extract(Token{"", 0, 6}, true, &sparse_features, &dense_features);
    287 
    288   std::vector<int> sparse_features2;
    289   extractor.Extract(Token{"", 0, 6}, true, &sparse_features2,
    290                     &dense_features);
    291   EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
    292 }
    293 
    294 TEST(TokenFeatureExtractorTest, RegexFeatures) {
    295   TokenFeatureExtractorOptions options;
    296   options.num_buckets = 1000;
    297   options.chargram_orders = std::vector<int>{1, 2};
    298   options.remap_digits = false;
    299   options.unicode_aware_features = false;
    300   options.regexp_features.push_back("^[a-z]+$");  // all lower case.
    301   options.regexp_features.push_back("^[0-9]+$");  // all digits.
    302   TokenFeatureExtractor extractor(options);
    303 
    304   std::vector<int> sparse_features;
    305   std::vector<float> dense_features;
    306   extractor.Extract(Token{"abCde", 0, 6}, true, &sparse_features,
    307                     &dense_features);
    308   EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
    309 
    310   dense_features.clear();
    311   extractor.Extract(Token{"abcde", 0, 6}, true, &sparse_features,
    312                     &dense_features);
    313   EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, -1.0}));
    314 
    315   dense_features.clear();
    316   extractor.Extract(Token{"12c45", 0, 6}, true, &sparse_features,
    317                     &dense_features);
    318   EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
    319 
    320   dense_features.clear();
    321   extractor.Extract(Token{"12345", 0, 6}, true, &sparse_features,
    322                     &dense_features);
    323   EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 1.0}));
    324 }
    325 
    326 TEST(TokenFeatureExtractorTest, ExtractTooLongWord) {
    327   TokenFeatureExtractorOptions options;
    328   options.num_buckets = 1000;
    329   options.chargram_orders = std::vector<int>{22};
    330   options.extract_case_feature = true;
    331   options.unicode_aware_features = true;
    332   options.extract_selection_mask_feature = true;
    333   TestingTokenFeatureExtractor extractor(options);
    334 
    335   // Test that this runs. ASAN should catch problems.
    336   std::vector<int> sparse_features;
    337   std::vector<float> dense_features;
    338   extractor.Extract(Token{"abcdefghijklmnopqstuvwxyz", 0, 0}, true,
    339                     &sparse_features, &dense_features);
    340 
    341   EXPECT_THAT(sparse_features,
    342               testing::ElementsAreArray({
    343                   // clang-format off
    344                   extractor.HashToken("^abcdefghij\1qstuvwxyz"),
    345                   extractor.HashToken("abcdefghij\1qstuvwxyz$"),
    346                   // clang-format on
    347               }));
    348 }
    349 
    350 TEST(TokenFeatureExtractorTest, ExtractAsciiUnicodeMatches) {
    351   TokenFeatureExtractorOptions options;
    352   options.num_buckets = 1000;
    353   options.chargram_orders = std::vector<int>{1, 2, 3, 4, 5};
    354   options.extract_case_feature = true;
    355   options.unicode_aware_features = true;
    356   options.extract_selection_mask_feature = true;
    357   TestingTokenFeatureExtractor extractor_unicode(options);
    358 
    359   options.unicode_aware_features = false;
    360   TestingTokenFeatureExtractor extractor_ascii(options);
    361 
    362   for (const std::string& input :
    363        {"https://www.abcdefgh.com/in/xxxkkkvayio",
    364         "https://www.fjsidofj.om/xx/abadfy/xxxx/?xfjiis=ffffiijiihil",
    365         "asdfhasdofjiasdofj#%()*%#*(aisdojfaosdifjiaofjdsiofjdi_fdis3w", "abcd",
    366         "x", "Hello", "Hey,", "Hi", ""}) {
    367     std::vector<int> sparse_features_unicode;
    368     std::vector<float> dense_features_unicode;
    369     extractor_unicode.Extract(Token{input, 0, 0}, true,
    370                               &sparse_features_unicode,
    371                               &dense_features_unicode);
    372 
    373     std::vector<int> sparse_features_ascii;
    374     std::vector<float> dense_features_ascii;
    375     extractor_ascii.Extract(Token{input, 0, 0}, true, &sparse_features_ascii,
    376                             &dense_features_ascii);
    377 
    378     EXPECT_THAT(sparse_features_unicode, sparse_features_ascii) << input;
    379     EXPECT_THAT(dense_features_unicode, dense_features_ascii) << input;
    380   }
    381 }
    382 
    383 TEST(TokenFeatureExtractorTest, ExtractForPadToken) {
    384   TokenFeatureExtractorOptions options;
    385   options.num_buckets = 1000;
    386   options.chargram_orders = std::vector<int>{1, 2};
    387   options.extract_case_feature = true;
    388   options.unicode_aware_features = false;
    389   options.extract_selection_mask_feature = true;
    390 
    391   TestingTokenFeatureExtractor extractor(options);
    392 
    393   std::vector<int> sparse_features;
    394   std::vector<float> dense_features;
    395 
    396   extractor.Extract(Token(), false, &sparse_features, &dense_features);
    397 
    398   EXPECT_THAT(sparse_features,
    399               testing::ElementsAreArray({extractor.HashToken("<PAD>")}));
    400   EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0}));
    401 }
    402 
    403 }  // namespace
    404 }  // namespace libtextclassifier
    405