1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "token-feature-extractor.h" 18 19 #include "gmock/gmock.h" 20 #include "gtest/gtest.h" 21 22 namespace libtextclassifier2 { 23 namespace { 24 25 class TestingTokenFeatureExtractor : public TokenFeatureExtractor { 26 public: 27 using TokenFeatureExtractor::HashToken; 28 using TokenFeatureExtractor::TokenFeatureExtractor; 29 }; 30 31 TEST(TokenFeatureExtractorTest, ExtractAscii) { 32 TokenFeatureExtractorOptions options; 33 options.num_buckets = 1000; 34 options.chargram_orders = std::vector<int>{1, 2, 3}; 35 options.extract_case_feature = true; 36 options.unicode_aware_features = false; 37 options.extract_selection_mask_feature = true; 38 CREATE_UNILIB_FOR_TESTING 39 TestingTokenFeatureExtractor extractor(options, unilib); 40 41 std::vector<int> sparse_features; 42 std::vector<float> dense_features; 43 44 extractor.Extract(Token{"Hello", 0, 5}, true, &sparse_features, 45 &dense_features); 46 47 EXPECT_THAT(sparse_features, 48 testing::ElementsAreArray({ 49 // clang-format off 50 extractor.HashToken("H"), 51 extractor.HashToken("e"), 52 extractor.HashToken("l"), 53 extractor.HashToken("l"), 54 extractor.HashToken("o"), 55 extractor.HashToken("^H"), 56 extractor.HashToken("He"), 57 extractor.HashToken("el"), 58 extractor.HashToken("ll"), 59 extractor.HashToken("lo"), 60 extractor.HashToken("o$"), 61 extractor.HashToken("^He"), 62 extractor.HashToken("Hel"), 63 extractor.HashToken("ell"), 64 extractor.HashToken("llo"), 65 extractor.HashToken("lo$") 66 // clang-format on 67 })); 68 EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0})); 69 70 sparse_features.clear(); 71 dense_features.clear(); 72 extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features, 73 &dense_features); 74 75 EXPECT_THAT(sparse_features, 76 testing::ElementsAreArray({ 77 // clang-format off 78 extractor.HashToken("w"), 79 extractor.HashToken("o"), 80 extractor.HashToken("r"), 81 extractor.HashToken("l"), 82 extractor.HashToken("d"), 83 extractor.HashToken("!"), 84 extractor.HashToken("^w"), 85 extractor.HashToken("wo"), 86 extractor.HashToken("or"), 87 extractor.HashToken("rl"), 88 extractor.HashToken("ld"), 89 extractor.HashToken("d!"), 90 extractor.HashToken("!$"), 91 extractor.HashToken("^wo"), 92 extractor.HashToken("wor"), 93 extractor.HashToken("orl"), 94 extractor.HashToken("rld"), 95 extractor.HashToken("ld!"), 96 extractor.HashToken("d!$"), 97 // clang-format on 98 })); 99 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0})); 100 } 101 102 TEST(TokenFeatureExtractorTest, ExtractAsciiNoChargrams) { 103 TokenFeatureExtractorOptions options; 104 options.num_buckets = 1000; 105 options.chargram_orders = std::vector<int>{}; 106 options.extract_case_feature = true; 107 options.unicode_aware_features = false; 108 options.extract_selection_mask_feature = true; 109 CREATE_UNILIB_FOR_TESTING 110 TestingTokenFeatureExtractor extractor(options, unilib); 111 112 std::vector<int> sparse_features; 113 std::vector<float> dense_features; 114 115 extractor.Extract(Token{"Hello", 0, 5}, true, &sparse_features, 116 &dense_features); 117 118 EXPECT_THAT(sparse_features, 119 testing::ElementsAreArray({extractor.HashToken("^Hello$")})); 120 EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0})); 121 122 sparse_features.clear(); 123 dense_features.clear(); 124 extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features, 125 &dense_features); 126 127 EXPECT_THAT(sparse_features, 128 testing::ElementsAreArray({extractor.HashToken("^world!$")})); 129 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0})); 130 } 131 132 TEST(TokenFeatureExtractorTest, ExtractUnicode) { 133 TokenFeatureExtractorOptions options; 134 options.num_buckets = 1000; 135 options.chargram_orders = std::vector<int>{1, 2, 3}; 136 options.extract_case_feature = true; 137 options.unicode_aware_features = true; 138 options.extract_selection_mask_feature = true; 139 CREATE_UNILIB_FOR_TESTING 140 TestingTokenFeatureExtractor extractor(options, unilib); 141 142 std::vector<int> sparse_features; 143 std::vector<float> dense_features; 144 145 extractor.Extract(Token{"Hll", 0, 5}, true, &sparse_features, 146 &dense_features); 147 148 EXPECT_THAT(sparse_features, 149 testing::ElementsAreArray({ 150 // clang-format off 151 extractor.HashToken("H"), 152 extractor.HashToken(""), 153 extractor.HashToken("l"), 154 extractor.HashToken("l"), 155 extractor.HashToken(""), 156 extractor.HashToken("^H"), 157 extractor.HashToken("H"), 158 extractor.HashToken("l"), 159 extractor.HashToken("ll"), 160 extractor.HashToken("l"), 161 extractor.HashToken("$"), 162 extractor.HashToken("^H"), 163 extractor.HashToken("Hl"), 164 extractor.HashToken("ll"), 165 extractor.HashToken("ll"), 166 extractor.HashToken("l$") 167 // clang-format on 168 })); 169 EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0})); 170 171 sparse_features.clear(); 172 dense_features.clear(); 173 extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features, 174 &dense_features); 175 176 EXPECT_THAT(sparse_features, 177 testing::ElementsAreArray({ 178 // clang-format off 179 extractor.HashToken("w"), 180 extractor.HashToken("o"), 181 extractor.HashToken("r"), 182 extractor.HashToken("l"), 183 extractor.HashToken("d"), 184 extractor.HashToken("!"), 185 extractor.HashToken("^w"), 186 extractor.HashToken("wo"), 187 extractor.HashToken("or"), 188 extractor.HashToken("rl"), 189 extractor.HashToken("ld"), 190 extractor.HashToken("d!"), 191 extractor.HashToken("!$"), 192 extractor.HashToken("^wo"), 193 extractor.HashToken("wor"), 194 extractor.HashToken("orl"), 195 extractor.HashToken("rld"), 196 extractor.HashToken("ld!"), 197 extractor.HashToken("d!$"), 198 // clang-format on 199 })); 200 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0})); 201 } 202 203 TEST(TokenFeatureExtractorTest, ExtractUnicodeNoChargrams) { 204 TokenFeatureExtractorOptions options; 205 options.num_buckets = 1000; 206 options.chargram_orders = std::vector<int>{}; 207 options.extract_case_feature = true; 208 options.unicode_aware_features = true; 209 options.extract_selection_mask_feature = true; 210 CREATE_UNILIB_FOR_TESTING 211 TestingTokenFeatureExtractor extractor(options, unilib); 212 213 std::vector<int> sparse_features; 214 std::vector<float> dense_features; 215 216 extractor.Extract(Token{"Hll", 0, 5}, true, &sparse_features, 217 &dense_features); 218 219 EXPECT_THAT(sparse_features, 220 testing::ElementsAreArray({extractor.HashToken("^Hll$")})); 221 EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0})); 222 223 sparse_features.clear(); 224 dense_features.clear(); 225 extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features, 226 &dense_features); 227 228 EXPECT_THAT(sparse_features, testing::ElementsAreArray({ 229 extractor.HashToken("^world!$"), 230 })); 231 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0})); 232 } 233 234 #ifdef LIBTEXTCLASSIFIER_TEST_ICU 235 TEST(TokenFeatureExtractorTest, ICUCaseFeature) { 236 TokenFeatureExtractorOptions options; 237 options.num_buckets = 1000; 238 options.chargram_orders = std::vector<int>{1, 2}; 239 options.extract_case_feature = true; 240 options.unicode_aware_features = true; 241 options.extract_selection_mask_feature = false; 242 CREATE_UNILIB_FOR_TESTING 243 TestingTokenFeatureExtractor extractor(options, unilib); 244 245 std::vector<int> sparse_features; 246 std::vector<float> dense_features; 247 extractor.Extract(Token{"Hll", 0, 5}, true, &sparse_features, 248 &dense_features); 249 EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0})); 250 251 sparse_features.clear(); 252 dense_features.clear(); 253 extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features, 254 &dense_features); 255 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0})); 256 257 sparse_features.clear(); 258 dense_features.clear(); 259 extractor.Extract(Token{"", 23, 29}, false, &sparse_features, 260 &dense_features); 261 EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0})); 262 263 sparse_features.clear(); 264 dense_features.clear(); 265 extractor.Extract(Token{"", 23, 29}, false, &sparse_features, 266 &dense_features); 267 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0})); 268 } 269 #endif 270 271 TEST(TokenFeatureExtractorTest, DigitRemapping) { 272 TokenFeatureExtractorOptions options; 273 options.num_buckets = 1000; 274 options.chargram_orders = std::vector<int>{1, 2}; 275 options.remap_digits = true; 276 options.unicode_aware_features = false; 277 CREATE_UNILIB_FOR_TESTING 278 TestingTokenFeatureExtractor extractor(options, unilib); 279 280 std::vector<int> sparse_features; 281 std::vector<float> dense_features; 282 extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features, 283 &dense_features); 284 285 std::vector<int> sparse_features2; 286 extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2, 287 &dense_features); 288 EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2)); 289 290 extractor.Extract(Token{"10:32am", 0, 6}, true, &sparse_features2, 291 &dense_features); 292 EXPECT_THAT(sparse_features, 293 testing::Not(testing::ElementsAreArray(sparse_features2))); 294 } 295 296 TEST(TokenFeatureExtractorTest, DigitRemappingUnicode) { 297 TokenFeatureExtractorOptions options; 298 options.num_buckets = 1000; 299 options.chargram_orders = std::vector<int>{1, 2}; 300 options.remap_digits = true; 301 options.unicode_aware_features = true; 302 CREATE_UNILIB_FOR_TESTING 303 TestingTokenFeatureExtractor extractor(options, unilib); 304 305 std::vector<int> sparse_features; 306 std::vector<float> dense_features; 307 extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features, 308 &dense_features); 309 310 std::vector<int> sparse_features2; 311 extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2, 312 &dense_features); 313 EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2)); 314 315 extractor.Extract(Token{"10:32am", 0, 6}, true, &sparse_features2, 316 &dense_features); 317 EXPECT_THAT(sparse_features, 318 testing::Not(testing::ElementsAreArray(sparse_features2))); 319 } 320 321 TEST(TokenFeatureExtractorTest, LowercaseAscii) { 322 TokenFeatureExtractorOptions options; 323 options.num_buckets = 1000; 324 options.chargram_orders = std::vector<int>{1, 2}; 325 options.lowercase_tokens = true; 326 options.unicode_aware_features = false; 327 CREATE_UNILIB_FOR_TESTING 328 TestingTokenFeatureExtractor extractor(options, unilib); 329 330 std::vector<int> sparse_features; 331 std::vector<float> dense_features; 332 extractor.Extract(Token{"AABB", 0, 6}, true, &sparse_features, 333 &dense_features); 334 335 std::vector<int> sparse_features2; 336 extractor.Extract(Token{"aaBB", 0, 6}, true, &sparse_features2, 337 &dense_features); 338 EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2)); 339 340 extractor.Extract(Token{"aAbB", 0, 6}, true, &sparse_features2, 341 &dense_features); 342 EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2)); 343 } 344 345 #ifdef LIBTEXTCLASSIFIER_TEST_ICU 346 TEST(TokenFeatureExtractorTest, LowercaseUnicode) { 347 TokenFeatureExtractorOptions options; 348 options.num_buckets = 1000; 349 options.chargram_orders = std::vector<int>{1, 2}; 350 options.lowercase_tokens = true; 351 options.unicode_aware_features = true; 352 CREATE_UNILIB_FOR_TESTING 353 TestingTokenFeatureExtractor extractor(options, unilib); 354 355 std::vector<int> sparse_features; 356 std::vector<float> dense_features; 357 extractor.Extract(Token{"", 0, 6}, true, &sparse_features, &dense_features); 358 359 std::vector<int> sparse_features2; 360 extractor.Extract(Token{"", 0, 6}, true, &sparse_features2, 361 &dense_features); 362 EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2)); 363 } 364 #endif 365 366 #ifdef LIBTEXTCLASSIFIER_TEST_ICU 367 TEST(TokenFeatureExtractorTest, RegexFeatures) { 368 TokenFeatureExtractorOptions options; 369 options.num_buckets = 1000; 370 options.chargram_orders = std::vector<int>{1, 2}; 371 options.remap_digits = false; 372 options.unicode_aware_features = false; 373 options.regexp_features.push_back("^[a-z]+$"); // all lower case. 374 options.regexp_features.push_back("^[0-9]+$"); // all digits. 375 CREATE_UNILIB_FOR_TESTING 376 TestingTokenFeatureExtractor extractor(options, unilib); 377 378 std::vector<int> sparse_features; 379 std::vector<float> dense_features; 380 extractor.Extract(Token{"abCde", 0, 6}, true, &sparse_features, 381 &dense_features); 382 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0})); 383 384 dense_features.clear(); 385 extractor.Extract(Token{"abcde", 0, 6}, true, &sparse_features, 386 &dense_features); 387 EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, -1.0})); 388 389 dense_features.clear(); 390 extractor.Extract(Token{"12c45", 0, 6}, true, &sparse_features, 391 &dense_features); 392 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0})); 393 394 dense_features.clear(); 395 extractor.Extract(Token{"12345", 0, 6}, true, &sparse_features, 396 &dense_features); 397 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 1.0})); 398 } 399 #endif 400 401 TEST(TokenFeatureExtractorTest, ExtractTooLongWord) { 402 TokenFeatureExtractorOptions options; 403 options.num_buckets = 1000; 404 options.chargram_orders = std::vector<int>{22}; 405 options.extract_case_feature = true; 406 options.unicode_aware_features = true; 407 options.extract_selection_mask_feature = true; 408 CREATE_UNILIB_FOR_TESTING 409 TestingTokenFeatureExtractor extractor(options, unilib); 410 411 // Test that this runs. ASAN should catch problems. 412 std::vector<int> sparse_features; 413 std::vector<float> dense_features; 414 extractor.Extract(Token{"abcdefghijklmnopqstuvwxyz", 0, 0}, true, 415 &sparse_features, &dense_features); 416 417 EXPECT_THAT(sparse_features, 418 testing::ElementsAreArray({ 419 // clang-format off 420 extractor.HashToken("^abcdefghij\1qstuvwxyz"), 421 extractor.HashToken("abcdefghij\1qstuvwxyz$"), 422 // clang-format on 423 })); 424 } 425 426 TEST(TokenFeatureExtractorTest, ExtractAsciiUnicodeMatches) { 427 TokenFeatureExtractorOptions options; 428 options.num_buckets = 1000; 429 options.chargram_orders = std::vector<int>{1, 2, 3, 4, 5}; 430 options.extract_case_feature = true; 431 options.unicode_aware_features = true; 432 options.extract_selection_mask_feature = true; 433 434 CREATE_UNILIB_FOR_TESTING 435 TestingTokenFeatureExtractor extractor_unicode(options, unilib); 436 437 options.unicode_aware_features = false; 438 TestingTokenFeatureExtractor extractor_ascii(options, unilib); 439 440 for (const std::string& input : 441 {"https://www.abcdefgh.com/in/xxxkkkvayio", 442 "https://www.fjsidofj.om/xx/abadfy/xxxx/?xfjiis=ffffiijiihil", 443 "asdfhasdofjiasdofj#%()*%#*(aisdojfaosdifjiaofjdsiofjdi_fdis3w", "abcd", 444 "x", "Hello", "Hey,", "Hi", ""}) { 445 std::vector<int> sparse_features_unicode; 446 std::vector<float> dense_features_unicode; 447 extractor_unicode.Extract(Token{input, 0, 0}, true, 448 &sparse_features_unicode, 449 &dense_features_unicode); 450 451 std::vector<int> sparse_features_ascii; 452 std::vector<float> dense_features_ascii; 453 extractor_ascii.Extract(Token{input, 0, 0}, true, &sparse_features_ascii, 454 &dense_features_ascii); 455 456 EXPECT_THAT(sparse_features_unicode, sparse_features_ascii) << input; 457 EXPECT_THAT(dense_features_unicode, dense_features_ascii) << input; 458 } 459 } 460 461 TEST(TokenFeatureExtractorTest, ExtractForPadToken) { 462 TokenFeatureExtractorOptions options; 463 options.num_buckets = 1000; 464 options.chargram_orders = std::vector<int>{1, 2}; 465 options.extract_case_feature = true; 466 options.unicode_aware_features = false; 467 options.extract_selection_mask_feature = true; 468 469 CREATE_UNILIB_FOR_TESTING 470 TestingTokenFeatureExtractor extractor(options, unilib); 471 472 std::vector<int> sparse_features; 473 std::vector<float> dense_features; 474 475 extractor.Extract(Token(), false, &sparse_features, &dense_features); 476 477 EXPECT_THAT(sparse_features, 478 testing::ElementsAreArray({extractor.HashToken("<PAD>")})); 479 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0})); 480 } 481 482 TEST(TokenFeatureExtractorTest, ExtractFiltered) { 483 TokenFeatureExtractorOptions options; 484 options.num_buckets = 1000; 485 options.chargram_orders = std::vector<int>{1, 2, 3}; 486 options.extract_case_feature = true; 487 options.unicode_aware_features = false; 488 options.extract_selection_mask_feature = true; 489 options.allowed_chargrams.insert("^H"); 490 options.allowed_chargrams.insert("ll"); 491 options.allowed_chargrams.insert("llo"); 492 options.allowed_chargrams.insert("w"); 493 options.allowed_chargrams.insert("!"); 494 options.allowed_chargrams.insert("\xc4"); // UTF8 control character. 495 496 CREATE_UNILIB_FOR_TESTING 497 TestingTokenFeatureExtractor extractor(options, unilib); 498 499 std::vector<int> sparse_features; 500 std::vector<float> dense_features; 501 502 extractor.Extract(Token{"Hllo", 0, 5}, true, &sparse_features, 503 &dense_features); 504 505 EXPECT_THAT(sparse_features, 506 testing::ElementsAreArray({ 507 // clang-format off 508 0, 509 extractor.HashToken("\xc4"), 510 0, 511 0, 512 0, 513 0, 514 extractor.HashToken("^H"), 515 0, 516 0, 517 0, 518 extractor.HashToken("ll"), 519 0, 520 0, 521 0, 522 0, 523 0, 524 0, 525 extractor.HashToken("llo"), 526 0 527 // clang-format on 528 })); 529 EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, 1.0})); 530 531 sparse_features.clear(); 532 dense_features.clear(); 533 extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features, 534 &dense_features); 535 536 EXPECT_THAT(sparse_features, testing::ElementsAreArray({ 537 // clang-format off 538 extractor.HashToken("w"), 539 0, 540 0, 541 0, 542 0, 543 extractor.HashToken("!"), 544 0, 545 0, 546 0, 547 0, 548 0, 549 0, 550 0, 551 0, 552 0, 553 0, 554 0, 555 0, 556 0, 557 // clang-format on 558 })); 559 EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 0.0})); 560 EXPECT_EQ(extractor.HashToken("<PAD>"), 1); 561 } 562 563 } // namespace 564 } // namespace libtextclassifier2 565