1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "smartselect/token-feature-extractor.h" 18 19 #include "gmock/gmock.h" 20 #include "gtest/gtest.h" 21 22 namespace libtextclassifier { 23 namespace { 24 25 class TestingTokenFeatureExtractor : public TokenFeatureExtractor { 26 public: 27 using TokenFeatureExtractor::TokenFeatureExtractor; 28 using TokenFeatureExtractor::HashToken; 29 }; 30 31 TEST(TokenFeatureExtractorTest, ExtractAscii) { 32 TokenFeatureExtractorOptions options; 33 options.num_buckets = 1000; 34 options.chargram_orders = std::vector<int>{1, 2, 3}; 35 options.extract_case_feature = true; 36 options.unicode_aware_features = false; 37 options.extract_selection_mask_feature = true; 38 TestingTokenFeatureExtractor extractor(options); 39 40 std::vector<int> sparse_features; 41 std::vector<float> dense_features; 42 43 extractor.Extract(Token{"Hello", 0, 5}, true, &sparse_features, 44 &dense_features); 45 46 EXPECT_THAT(sparse_features, 47 testing::ElementsAreArray({ 48 // clang-format off 49 extractor.HashToken("H"), 50 extractor.HashToken("e"), 51 extractor.HashToken("l"), 52 extractor.HashToken("l"), 53 extractor.HashToken("o"), 54 extractor.HashToken("^H"), 55 extractor.HashToken("He"), 56 extractor.HashToken("el"), 57 extractor.HashToken("ll"), 58 extractor.HashToken("lo"), 59 extractor.HashToken("o$"), 60 extractor.HashToken("^He"), 61 extractor.HashToken("Hel"), 62 extractor.HashToken("ell"), 63 extractor.HashToken("llo"), 64 extractor.HashToken("lo$") 65 // clang-format on 66 })); 67 EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0})); 68 69 sparse_features.clear(); 70 dense_features.clear(); 71 extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features, 72 &dense_features); 73 74 EXPECT_THAT(sparse_features, 75 testing::ElementsAreArray({ 76 // clang-format off 77 extractor.HashToken("w"), 78 extractor.HashToken("o"), 79 extractor.HashToken("r"), 80 extractor.HashToken("l"), 81 extractor.HashToken("d"), 82 extractor.HashToken("!"), 83 extractor.HashToken("^w"), 84 extractor.HashToken("wo"), 85 extractor.HashToken("or"), 86 extractor.HashToken("rl"), 87 extractor.HashToken("ld"), 88 extractor.HashToken("d!"), 89 extractor.HashToken("!$"), 90 extractor.HashToken("^wo"), 91 extractor.HashToken("wor"), 92 extractor.HashToken("orl"), 93 extractor.HashToken("rld"), 94 extractor.HashToken("ld!"), 95 extractor.HashToken("d!$"), 96 // clang-format on 97 })); 98 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0})); 99 } 100 101 TEST(TokenFeatureExtractorTest, ExtractUnicode) { 102 TokenFeatureExtractorOptions options; 103 options.num_buckets = 1000; 104 options.chargram_orders = std::vector<int>{1, 2, 3}; 105 options.extract_case_feature = true; 106 options.unicode_aware_features = true; 107 options.extract_selection_mask_feature = true; 108 TestingTokenFeatureExtractor extractor(options); 109 110 std::vector<int> sparse_features; 111 std::vector<float> dense_features; 112 113 extractor.Extract(Token{"Hll", 0, 5}, true, &sparse_features, 114 &dense_features); 115 116 EXPECT_THAT(sparse_features, 117 testing::ElementsAreArray({ 118 // clang-format off 119 extractor.HashToken("H"), 120 extractor.HashToken(""), 121 extractor.HashToken("l"), 122 extractor.HashToken("l"), 123 extractor.HashToken(""), 124 extractor.HashToken("^H"), 125 extractor.HashToken("H"), 126 extractor.HashToken("l"), 127 extractor.HashToken("ll"), 128 extractor.HashToken("l"), 129 extractor.HashToken("$"), 130 extractor.HashToken("^H"), 131 extractor.HashToken("Hl"), 132 extractor.HashToken("ll"), 133 extractor.HashToken("ll"), 134 extractor.HashToken("l$") 135 // clang-format on 136 })); 137 EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0})); 138 139 sparse_features.clear(); 140 dense_features.clear(); 141 extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features, 142 &dense_features); 143 144 EXPECT_THAT(sparse_features, 145 testing::ElementsAreArray({ 146 // clang-format off 147 extractor.HashToken("w"), 148 extractor.HashToken("o"), 149 extractor.HashToken("r"), 150 extractor.HashToken("l"), 151 extractor.HashToken("d"), 152 extractor.HashToken("!"), 153 extractor.HashToken("^w"), 154 extractor.HashToken("wo"), 155 extractor.HashToken("or"), 156 extractor.HashToken("rl"), 157 extractor.HashToken("ld"), 158 extractor.HashToken("d!"), 159 extractor.HashToken("!$"), 160 extractor.HashToken("^wo"), 161 extractor.HashToken("wor"), 162 extractor.HashToken("orl"), 163 extractor.HashToken("rld"), 164 extractor.HashToken("ld!"), 165 extractor.HashToken("d!$"), 166 // clang-format on 167 })); 168 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0})); 169 } 170 171 TEST(TokenFeatureExtractorTest, ICUCaseFeature) { 172 TokenFeatureExtractorOptions options; 173 options.num_buckets = 1000; 174 options.chargram_orders = std::vector<int>{1, 2}; 175 options.extract_case_feature = true; 176 options.unicode_aware_features = true; 177 options.extract_selection_mask_feature = false; 178 TokenFeatureExtractor extractor(options); 179 180 std::vector<int> sparse_features; 181 std::vector<float> dense_features; 182 extractor.Extract(Token{"Hll", 0, 5}, true, &sparse_features, 183 &dense_features); 184 EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0})); 185 186 sparse_features.clear(); 187 dense_features.clear(); 188 extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features, 189 &dense_features); 190 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0})); 191 192 sparse_features.clear(); 193 dense_features.clear(); 194 extractor.Extract(Token{"", 23, 29}, false, &sparse_features, 195 &dense_features); 196 EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0})); 197 198 sparse_features.clear(); 199 dense_features.clear(); 200 extractor.Extract(Token{"", 23, 29}, false, &sparse_features, 201 &dense_features); 202 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0})); 203 } 204 205 TEST(TokenFeatureExtractorTest, DigitRemapping) { 206 TokenFeatureExtractorOptions options; 207 options.num_buckets = 1000; 208 options.chargram_orders = std::vector<int>{1, 2}; 209 options.remap_digits = true; 210 options.unicode_aware_features = false; 211 TokenFeatureExtractor extractor(options); 212 213 std::vector<int> sparse_features; 214 std::vector<float> dense_features; 215 extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features, 216 &dense_features); 217 218 std::vector<int> sparse_features2; 219 extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2, 220 &dense_features); 221 EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2)); 222 223 extractor.Extract(Token{"10:32am", 0, 6}, true, &sparse_features2, 224 &dense_features); 225 EXPECT_THAT(sparse_features, 226 testing::Not(testing::ElementsAreArray(sparse_features2))); 227 } 228 229 TEST(TokenFeatureExtractorTest, DigitRemappingUnicode) { 230 TokenFeatureExtractorOptions options; 231 options.num_buckets = 1000; 232 options.chargram_orders = std::vector<int>{1, 2}; 233 options.remap_digits = true; 234 options.unicode_aware_features = true; 235 TokenFeatureExtractor extractor(options); 236 237 std::vector<int> sparse_features; 238 std::vector<float> dense_features; 239 extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features, 240 &dense_features); 241 242 std::vector<int> sparse_features2; 243 extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2, 244 &dense_features); 245 EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2)); 246 247 extractor.Extract(Token{"10:32am", 0, 6}, true, &sparse_features2, 248 &dense_features); 249 EXPECT_THAT(sparse_features, 250 testing::Not(testing::ElementsAreArray(sparse_features2))); 251 } 252 253 TEST(TokenFeatureExtractorTest, LowercaseAscii) { 254 TokenFeatureExtractorOptions options; 255 options.num_buckets = 1000; 256 options.chargram_orders = std::vector<int>{1, 2}; 257 options.lowercase_tokens = true; 258 options.unicode_aware_features = false; 259 TokenFeatureExtractor extractor(options); 260 261 std::vector<int> sparse_features; 262 std::vector<float> dense_features; 263 extractor.Extract(Token{"AABB", 0, 6}, true, &sparse_features, 264 &dense_features); 265 266 std::vector<int> sparse_features2; 267 extractor.Extract(Token{"aaBB", 0, 6}, true, &sparse_features2, 268 &dense_features); 269 EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2)); 270 271 extractor.Extract(Token{"aAbB", 0, 6}, true, &sparse_features2, 272 &dense_features); 273 EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2)); 274 } 275 276 TEST(TokenFeatureExtractorTest, LowercaseUnicode) { 277 TokenFeatureExtractorOptions options; 278 options.num_buckets = 1000; 279 options.chargram_orders = std::vector<int>{1, 2}; 280 options.lowercase_tokens = true; 281 options.unicode_aware_features = true; 282 TokenFeatureExtractor extractor(options); 283 284 std::vector<int> sparse_features; 285 std::vector<float> dense_features; 286 extractor.Extract(Token{"", 0, 6}, true, &sparse_features, &dense_features); 287 288 std::vector<int> sparse_features2; 289 extractor.Extract(Token{"", 0, 6}, true, &sparse_features2, 290 &dense_features); 291 EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2)); 292 } 293 294 TEST(TokenFeatureExtractorTest, RegexFeatures) { 295 TokenFeatureExtractorOptions options; 296 options.num_buckets = 1000; 297 options.chargram_orders = std::vector<int>{1, 2}; 298 options.remap_digits = false; 299 options.unicode_aware_features = false; 300 options.regexp_features.push_back("^[a-z]+$"); // all lower case. 301 options.regexp_features.push_back("^[0-9]+$"); // all digits. 302 TokenFeatureExtractor extractor(options); 303 304 std::vector<int> sparse_features; 305 std::vector<float> dense_features; 306 extractor.Extract(Token{"abCde", 0, 6}, true, &sparse_features, 307 &dense_features); 308 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0})); 309 310 dense_features.clear(); 311 extractor.Extract(Token{"abcde", 0, 6}, true, &sparse_features, 312 &dense_features); 313 EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, -1.0})); 314 315 dense_features.clear(); 316 extractor.Extract(Token{"12c45", 0, 6}, true, &sparse_features, 317 &dense_features); 318 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0})); 319 320 dense_features.clear(); 321 extractor.Extract(Token{"12345", 0, 6}, true, &sparse_features, 322 &dense_features); 323 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 1.0})); 324 } 325 326 TEST(TokenFeatureExtractorTest, ExtractTooLongWord) { 327 TokenFeatureExtractorOptions options; 328 options.num_buckets = 1000; 329 options.chargram_orders = std::vector<int>{22}; 330 options.extract_case_feature = true; 331 options.unicode_aware_features = true; 332 options.extract_selection_mask_feature = true; 333 TestingTokenFeatureExtractor extractor(options); 334 335 // Test that this runs. ASAN should catch problems. 336 std::vector<int> sparse_features; 337 std::vector<float> dense_features; 338 extractor.Extract(Token{"abcdefghijklmnopqstuvwxyz", 0, 0}, true, 339 &sparse_features, &dense_features); 340 341 EXPECT_THAT(sparse_features, 342 testing::ElementsAreArray({ 343 // clang-format off 344 extractor.HashToken("^abcdefghij\1qstuvwxyz"), 345 extractor.HashToken("abcdefghij\1qstuvwxyz$"), 346 // clang-format on 347 })); 348 } 349 350 TEST(TokenFeatureExtractorTest, ExtractAsciiUnicodeMatches) { 351 TokenFeatureExtractorOptions options; 352 options.num_buckets = 1000; 353 options.chargram_orders = std::vector<int>{1, 2, 3, 4, 5}; 354 options.extract_case_feature = true; 355 options.unicode_aware_features = true; 356 options.extract_selection_mask_feature = true; 357 TestingTokenFeatureExtractor extractor_unicode(options); 358 359 options.unicode_aware_features = false; 360 TestingTokenFeatureExtractor extractor_ascii(options); 361 362 for (const std::string& input : 363 {"https://www.abcdefgh.com/in/xxxkkkvayio", 364 "https://www.fjsidofj.om/xx/abadfy/xxxx/?xfjiis=ffffiijiihil", 365 "asdfhasdofjiasdofj#%()*%#*(aisdojfaosdifjiaofjdsiofjdi_fdis3w", "abcd", 366 "x", "Hello", "Hey,", "Hi", ""}) { 367 std::vector<int> sparse_features_unicode; 368 std::vector<float> dense_features_unicode; 369 extractor_unicode.Extract(Token{input, 0, 0}, true, 370 &sparse_features_unicode, 371 &dense_features_unicode); 372 373 std::vector<int> sparse_features_ascii; 374 std::vector<float> dense_features_ascii; 375 extractor_ascii.Extract(Token{input, 0, 0}, true, &sparse_features_ascii, 376 &dense_features_ascii); 377 378 EXPECT_THAT(sparse_features_unicode, sparse_features_ascii) << input; 379 EXPECT_THAT(dense_features_unicode, dense_features_ascii) << input; 380 } 381 } 382 383 TEST(TokenFeatureExtractorTest, ExtractForPadToken) { 384 TokenFeatureExtractorOptions options; 385 options.num_buckets = 1000; 386 options.chargram_orders = std::vector<int>{1, 2}; 387 options.extract_case_feature = true; 388 options.unicode_aware_features = false; 389 options.extract_selection_mask_feature = true; 390 391 TestingTokenFeatureExtractor extractor(options); 392 393 std::vector<int> sparse_features; 394 std::vector<float> dense_features; 395 396 extractor.Extract(Token(), false, &sparse_features, &dense_features); 397 398 EXPECT_THAT(sparse_features, 399 testing::ElementsAreArray({extractor.HashToken("<PAD>")})); 400 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0})); 401 } 402 403 } // namespace 404 } // namespace libtextclassifier 405