1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "annotator/feature-processor.h" 18 19 #include "annotator/model-executor.h" 20 #include "utils/tensor-view.h" 21 22 #include "gmock/gmock.h" 23 #include "gtest/gtest.h" 24 25 namespace libtextclassifier3 { 26 namespace { 27 28 using testing::ElementsAreArray; 29 using testing::FloatEq; 30 using testing::Matcher; 31 32 flatbuffers::DetachedBuffer PackFeatureProcessorOptions( 33 const FeatureProcessorOptionsT& options) { 34 flatbuffers::FlatBufferBuilder builder; 35 builder.Finish(CreateFeatureProcessorOptions(builder, &options)); 36 return builder.Release(); 37 } 38 39 template <typename T> 40 std::vector<T> Subvector(const std::vector<T>& vector, int start, int end) { 41 return std::vector<T>(vector.begin() + start, vector.begin() + end); 42 } 43 44 Matcher<std::vector<float>> ElementsAreFloat(const std::vector<float>& values) { 45 std::vector<Matcher<float>> matchers; 46 for (const float value : values) { 47 matchers.push_back(FloatEq(value)); 48 } 49 return ElementsAreArray(matchers); 50 } 51 52 class TestingFeatureProcessor : public FeatureProcessor { 53 public: 54 using FeatureProcessor::CountIgnoredSpanBoundaryCodepoints; 55 using FeatureProcessor::FeatureProcessor; 56 using FeatureProcessor::SpanToLabel; 57 using FeatureProcessor::StripTokensFromOtherLines; 58 using FeatureProcessor::supported_codepoint_ranges_; 59 
using FeatureProcessor::SupportedCodepointsRatio; 60 }; 61 62 // EmbeddingExecutor that always returns features based on 63 class FakeEmbeddingExecutor : public EmbeddingExecutor { 64 public: 65 bool AddEmbedding(const TensorView<int>& sparse_features, float* dest, 66 int dest_size) const override { 67 TC3_CHECK_GE(dest_size, 4); 68 EXPECT_EQ(sparse_features.size(), 1); 69 dest[0] = sparse_features.data()[0]; 70 dest[1] = sparse_features.data()[0]; 71 dest[2] = -sparse_features.data()[0]; 72 dest[3] = -sparse_features.data()[0]; 73 return true; 74 } 75 76 private: 77 std::vector<float> storage_; 78 }; 79 80 class FeatureProcessorTest : public ::testing::Test { 81 protected: 82 FeatureProcessorTest() : INIT_UNILIB_FOR_TESTING(unilib_) {} 83 UniLib unilib_; 84 }; 85 86 TEST_F(FeatureProcessorTest, SplitTokensOnSelectionBoundariesMiddle) { 87 std::vector<Token> tokens{Token("Hll", 0, 5), 88 Token("fba@google.com", 6, 23), 89 Token("hee!", 24, 29)}; 90 91 internal::SplitTokensOnSelectionBoundaries({9, 12}, &tokens); 92 93 // clang-format off 94 EXPECT_THAT(tokens, ElementsAreArray( 95 {Token("Hll", 0, 5), 96 Token("f", 6, 9), 97 Token("ba", 9, 12), 98 Token("@google.com", 12, 23), 99 Token("hee!", 24, 29)})); 100 // clang-format on 101 } 102 103 TEST_F(FeatureProcessorTest, SplitTokensOnSelectionBoundariesBegin) { 104 std::vector<Token> tokens{Token("Hll", 0, 5), 105 Token("fba@google.com", 6, 23), 106 Token("hee!", 24, 29)}; 107 108 internal::SplitTokensOnSelectionBoundaries({6, 12}, &tokens); 109 110 // clang-format off 111 EXPECT_THAT(tokens, ElementsAreArray( 112 {Token("Hll", 0, 5), 113 Token("fba", 6, 12), 114 Token("@google.com", 12, 23), 115 Token("hee!", 24, 29)})); 116 // clang-format on 117 } 118 119 TEST_F(FeatureProcessorTest, SplitTokensOnSelectionBoundariesEnd) { 120 std::vector<Token> tokens{Token("Hll", 0, 5), 121 Token("fba@google.com", 6, 23), 122 Token("hee!", 24, 29)}; 123 124 internal::SplitTokensOnSelectionBoundaries({9, 23}, &tokens); 125 126 // 
clang-format off 127 EXPECT_THAT(tokens, ElementsAreArray( 128 {Token("Hll", 0, 5), 129 Token("f", 6, 9), 130 Token("ba@google.com", 9, 23), 131 Token("hee!", 24, 29)})); 132 // clang-format on 133 } 134 135 TEST_F(FeatureProcessorTest, SplitTokensOnSelectionBoundariesWhole) { 136 std::vector<Token> tokens{Token("Hll", 0, 5), 137 Token("fba@google.com", 6, 23), 138 Token("hee!", 24, 29)}; 139 140 internal::SplitTokensOnSelectionBoundaries({6, 23}, &tokens); 141 142 // clang-format off 143 EXPECT_THAT(tokens, ElementsAreArray( 144 {Token("Hll", 0, 5), 145 Token("fba@google.com", 6, 23), 146 Token("hee!", 24, 29)})); 147 // clang-format on 148 } 149 150 TEST_F(FeatureProcessorTest, SplitTokensOnSelectionBoundariesCrossToken) { 151 std::vector<Token> tokens{Token("Hll", 0, 5), 152 Token("fba@google.com", 6, 23), 153 Token("hee!", 24, 29)}; 154 155 internal::SplitTokensOnSelectionBoundaries({2, 9}, &tokens); 156 157 // clang-format off 158 EXPECT_THAT(tokens, ElementsAreArray( 159 {Token("H", 0, 2), 160 Token("ll", 2, 5), 161 Token("f", 6, 9), 162 Token("ba@google.com", 9, 23), 163 Token("hee!", 24, 29)})); 164 // clang-format on 165 } 166 167 TEST_F(FeatureProcessorTest, KeepLineWithClickFirst) { 168 FeatureProcessorOptionsT options; 169 options.only_use_line_with_click = true; 170 flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); 171 TestingFeatureProcessor feature_processor( 172 flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), 173 &unilib_); 174 175 const std::string context = "Fist Lin\nScond Lin\nThid Lin"; 176 const CodepointSpan span = {0, 5}; 177 // clang-format off 178 std::vector<Token> tokens = {Token("Fist", 0, 5), 179 Token("Lin", 6, 10), 180 Token("Scond", 11, 17), 181 Token("Lin", 18, 22), 182 Token("Thid", 23, 28), 183 Token("Lin", 29, 33)}; 184 // clang-format on 185 186 // Keeps the first line. 
187 feature_processor.StripTokensFromOtherLines(context, span, &tokens); 188 EXPECT_THAT(tokens, 189 ElementsAreArray({Token("Fist", 0, 5), Token("Lin", 6, 10)})); 190 } 191 192 TEST_F(FeatureProcessorTest, KeepLineWithClickSecond) { 193 FeatureProcessorOptionsT options; 194 options.only_use_line_with_click = true; 195 flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); 196 TestingFeatureProcessor feature_processor( 197 flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), 198 &unilib_); 199 200 const std::string context = "Fist Lin\nScond Lin\nThid Lin"; 201 const CodepointSpan span = {18, 22}; 202 // clang-format off 203 std::vector<Token> tokens = {Token("Fist", 0, 5), 204 Token("Lin", 6, 10), 205 Token("Scond", 11, 17), 206 Token("Lin", 18, 22), 207 Token("Thid", 23, 28), 208 Token("Lin", 29, 33)}; 209 // clang-format on 210 211 // Keeps the first line. 212 feature_processor.StripTokensFromOtherLines(context, span, &tokens); 213 EXPECT_THAT(tokens, ElementsAreArray( 214 {Token("Scond", 11, 17), Token("Lin", 18, 22)})); 215 } 216 217 TEST_F(FeatureProcessorTest, KeepLineWithClickThird) { 218 FeatureProcessorOptionsT options; 219 options.only_use_line_with_click = true; 220 flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); 221 TestingFeatureProcessor feature_processor( 222 flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), 223 &unilib_); 224 225 const std::string context = "Fist Lin\nScond Lin\nThid Lin"; 226 const CodepointSpan span = {24, 33}; 227 // clang-format off 228 std::vector<Token> tokens = {Token("Fist", 0, 5), 229 Token("Lin", 6, 10), 230 Token("Scond", 11, 17), 231 Token("Lin", 18, 22), 232 Token("Thid", 23, 28), 233 Token("Lin", 29, 33)}; 234 // clang-format on 235 236 // Keeps the first line. 
237 feature_processor.StripTokensFromOtherLines(context, span, &tokens); 238 EXPECT_THAT(tokens, ElementsAreArray( 239 {Token("Thid", 23, 28), Token("Lin", 29, 33)})); 240 } 241 242 TEST_F(FeatureProcessorTest, KeepLineWithClickSecondWithPipe) { 243 FeatureProcessorOptionsT options; 244 options.only_use_line_with_click = true; 245 flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); 246 TestingFeatureProcessor feature_processor( 247 flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), 248 &unilib_); 249 250 const std::string context = "Fist Lin|Scond Lin\nThid Lin"; 251 const CodepointSpan span = {18, 22}; 252 // clang-format off 253 std::vector<Token> tokens = {Token("Fist", 0, 5), 254 Token("Lin", 6, 10), 255 Token("Scond", 11, 17), 256 Token("Lin", 18, 22), 257 Token("Thid", 23, 28), 258 Token("Lin", 29, 33)}; 259 // clang-format on 260 261 // Keeps the first line. 262 feature_processor.StripTokensFromOtherLines(context, span, &tokens); 263 EXPECT_THAT(tokens, ElementsAreArray( 264 {Token("Scond", 11, 17), Token("Lin", 18, 22)})); 265 } 266 267 TEST_F(FeatureProcessorTest, KeepLineWithCrosslineClick) { 268 FeatureProcessorOptionsT options; 269 options.only_use_line_with_click = true; 270 flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); 271 TestingFeatureProcessor feature_processor( 272 flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), 273 &unilib_); 274 275 const std::string context = "Fist Lin\nScond Lin\nThid Lin"; 276 const CodepointSpan span = {5, 23}; 277 // clang-format off 278 std::vector<Token> tokens = {Token("Fist", 0, 5), 279 Token("Lin", 6, 10), 280 Token("Scond", 18, 23), 281 Token("Lin", 19, 23), 282 Token("Thid", 23, 28), 283 Token("Lin", 29, 33)}; 284 // clang-format on 285 286 // Keeps the first line. 
287 feature_processor.StripTokensFromOtherLines(context, span, &tokens); 288 EXPECT_THAT(tokens, ElementsAreArray( 289 {Token("Fist", 0, 5), Token("Lin", 6, 10), 290 Token("Scond", 18, 23), Token("Lin", 19, 23), 291 Token("Thid", 23, 28), Token("Lin", 29, 33)})); 292 } 293 294 TEST_F(FeatureProcessorTest, SpanToLabel) { 295 FeatureProcessorOptionsT options; 296 options.context_size = 1; 297 options.max_selection_span = 1; 298 options.snap_label_span_boundaries_to_containing_tokens = false; 299 300 options.tokenization_codepoint_config.emplace_back( 301 new TokenizationCodepointRangeT()); 302 auto& config = options.tokenization_codepoint_config.back(); 303 config->start = 32; 304 config->end = 33; 305 config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR; 306 307 flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); 308 TestingFeatureProcessor feature_processor( 309 flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), 310 &unilib_); 311 std::vector<Token> tokens = feature_processor.Tokenize("one, two, three"); 312 ASSERT_EQ(3, tokens.size()); 313 int label; 314 ASSERT_TRUE(feature_processor.SpanToLabel({5, 8}, tokens, &label)); 315 EXPECT_EQ(kInvalidLabel, label); 316 ASSERT_TRUE(feature_processor.SpanToLabel({5, 9}, tokens, &label)); 317 EXPECT_NE(kInvalidLabel, label); 318 TokenSpan token_span; 319 feature_processor.LabelToTokenSpan(label, &token_span); 320 EXPECT_EQ(0, token_span.first); 321 EXPECT_EQ(0, token_span.second); 322 323 // Reconfigure with snapping enabled. 
324 options.snap_label_span_boundaries_to_containing_tokens = true; 325 flatbuffers::DetachedBuffer options2_fb = 326 PackFeatureProcessorOptions(options); 327 TestingFeatureProcessor feature_processor2( 328 flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()), 329 &unilib_); 330 int label2; 331 ASSERT_TRUE(feature_processor2.SpanToLabel({5, 8}, tokens, &label2)); 332 EXPECT_EQ(label, label2); 333 ASSERT_TRUE(feature_processor2.SpanToLabel({6, 9}, tokens, &label2)); 334 EXPECT_EQ(label, label2); 335 ASSERT_TRUE(feature_processor2.SpanToLabel({5, 9}, tokens, &label2)); 336 EXPECT_EQ(label, label2); 337 338 // Cross a token boundary. 339 ASSERT_TRUE(feature_processor2.SpanToLabel({4, 9}, tokens, &label2)); 340 EXPECT_EQ(kInvalidLabel, label2); 341 ASSERT_TRUE(feature_processor2.SpanToLabel({5, 10}, tokens, &label2)); 342 EXPECT_EQ(kInvalidLabel, label2); 343 344 // Multiple tokens. 345 options.context_size = 2; 346 options.max_selection_span = 2; 347 flatbuffers::DetachedBuffer options3_fb = 348 PackFeatureProcessorOptions(options); 349 TestingFeatureProcessor feature_processor3( 350 flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()), 351 &unilib_); 352 tokens = feature_processor3.Tokenize("zero, one, two, three, four"); 353 ASSERT_TRUE(feature_processor3.SpanToLabel({6, 15}, tokens, &label2)); 354 EXPECT_NE(kInvalidLabel, label2); 355 feature_processor3.LabelToTokenSpan(label2, &token_span); 356 EXPECT_EQ(1, token_span.first); 357 EXPECT_EQ(0, token_span.second); 358 359 int label3; 360 ASSERT_TRUE(feature_processor3.SpanToLabel({6, 14}, tokens, &label3)); 361 EXPECT_EQ(label2, label3); 362 ASSERT_TRUE(feature_processor3.SpanToLabel({6, 13}, tokens, &label3)); 363 EXPECT_EQ(label2, label3); 364 ASSERT_TRUE(feature_processor3.SpanToLabel({7, 13}, tokens, &label3)); 365 EXPECT_EQ(label2, label3); 366 } 367 368 TEST_F(FeatureProcessorTest, SpanToLabelIgnoresPunctuation) { 369 FeatureProcessorOptionsT options; 370 options.context_size = 1; 
371 options.max_selection_span = 1; 372 options.snap_label_span_boundaries_to_containing_tokens = false; 373 374 options.tokenization_codepoint_config.emplace_back( 375 new TokenizationCodepointRangeT()); 376 auto& config = options.tokenization_codepoint_config.back(); 377 config->start = 32; 378 config->end = 33; 379 config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR; 380 381 flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); 382 TestingFeatureProcessor feature_processor( 383 flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), 384 &unilib_); 385 std::vector<Token> tokens = feature_processor.Tokenize("one, two, three"); 386 ASSERT_EQ(3, tokens.size()); 387 int label; 388 ASSERT_TRUE(feature_processor.SpanToLabel({5, 8}, tokens, &label)); 389 EXPECT_EQ(kInvalidLabel, label); 390 ASSERT_TRUE(feature_processor.SpanToLabel({5, 9}, tokens, &label)); 391 EXPECT_NE(kInvalidLabel, label); 392 TokenSpan token_span; 393 feature_processor.LabelToTokenSpan(label, &token_span); 394 EXPECT_EQ(0, token_span.first); 395 EXPECT_EQ(0, token_span.second); 396 397 // Reconfigure with snapping enabled. 398 options.snap_label_span_boundaries_to_containing_tokens = true; 399 flatbuffers::DetachedBuffer options2_fb = 400 PackFeatureProcessorOptions(options); 401 TestingFeatureProcessor feature_processor2( 402 flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()), 403 &unilib_); 404 int label2; 405 ASSERT_TRUE(feature_processor2.SpanToLabel({5, 8}, tokens, &label2)); 406 EXPECT_EQ(label, label2); 407 ASSERT_TRUE(feature_processor2.SpanToLabel({6, 9}, tokens, &label2)); 408 EXPECT_EQ(label, label2); 409 ASSERT_TRUE(feature_processor2.SpanToLabel({5, 9}, tokens, &label2)); 410 EXPECT_EQ(label, label2); 411 412 // Cross a token boundary. 
413 ASSERT_TRUE(feature_processor2.SpanToLabel({4, 9}, tokens, &label2)); 414 EXPECT_EQ(kInvalidLabel, label2); 415 ASSERT_TRUE(feature_processor2.SpanToLabel({5, 10}, tokens, &label2)); 416 EXPECT_EQ(kInvalidLabel, label2); 417 418 // Multiple tokens. 419 options.context_size = 2; 420 options.max_selection_span = 2; 421 flatbuffers::DetachedBuffer options3_fb = 422 PackFeatureProcessorOptions(options); 423 TestingFeatureProcessor feature_processor3( 424 flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()), 425 &unilib_); 426 tokens = feature_processor3.Tokenize("zero, one, two, three, four"); 427 ASSERT_TRUE(feature_processor3.SpanToLabel({6, 15}, tokens, &label2)); 428 EXPECT_NE(kInvalidLabel, label2); 429 feature_processor3.LabelToTokenSpan(label2, &token_span); 430 EXPECT_EQ(1, token_span.first); 431 EXPECT_EQ(0, token_span.second); 432 433 int label3; 434 ASSERT_TRUE(feature_processor3.SpanToLabel({6, 14}, tokens, &label3)); 435 EXPECT_EQ(label2, label3); 436 ASSERT_TRUE(feature_processor3.SpanToLabel({6, 13}, tokens, &label3)); 437 EXPECT_EQ(label2, label3); 438 ASSERT_TRUE(feature_processor3.SpanToLabel({7, 13}, tokens, &label3)); 439 EXPECT_EQ(label2, label3); 440 } 441 442 TEST_F(FeatureProcessorTest, CenterTokenFromClick) { 443 int token_index; 444 445 // Exactly aligned indices. 446 token_index = internal::CenterTokenFromClick( 447 {6, 11}, 448 {Token("Hll", 0, 5), Token("world", 6, 11), Token("hee!", 12, 17)}); 449 EXPECT_EQ(token_index, 1); 450 451 // Click is contained in a token. 452 token_index = internal::CenterTokenFromClick( 453 {13, 17}, 454 {Token("Hll", 0, 5), Token("world", 6, 11), Token("hee!", 12, 17)}); 455 EXPECT_EQ(token_index, 2); 456 457 // Click spans two tokens. 
458 token_index = internal::CenterTokenFromClick( 459 {6, 17}, 460 {Token("Hll", 0, 5), Token("world", 6, 11), Token("hee!", 12, 17)}); 461 EXPECT_EQ(token_index, kInvalidIndex); 462 } 463 464 TEST_F(FeatureProcessorTest, CenterTokenFromMiddleOfSelection) { 465 int token_index; 466 467 // Selection of length 3. Exactly aligned indices. 468 token_index = internal::CenterTokenFromMiddleOfSelection( 469 {7, 27}, 470 {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20), 471 Token("Token4", 21, 27), Token("Token5", 28, 34)}); 472 EXPECT_EQ(token_index, 2); 473 474 // Selection of length 1 token. Exactly aligned indices. 475 token_index = internal::CenterTokenFromMiddleOfSelection( 476 {21, 27}, 477 {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20), 478 Token("Token4", 21, 27), Token("Token5", 28, 34)}); 479 EXPECT_EQ(token_index, 3); 480 481 // Selection marks sub-token range, with no tokens in it. 482 token_index = internal::CenterTokenFromMiddleOfSelection( 483 {29, 33}, 484 {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20), 485 Token("Token4", 21, 27), Token("Token5", 28, 34)}); 486 EXPECT_EQ(token_index, kInvalidIndex); 487 488 // Selection of length 2. Sub-token indices. 489 token_index = internal::CenterTokenFromMiddleOfSelection( 490 {3, 25}, 491 {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20), 492 Token("Token4", 21, 27), Token("Token5", 28, 34)}); 493 EXPECT_EQ(token_index, 1); 494 495 // Selection of length 1. Sub-token indices. 496 token_index = internal::CenterTokenFromMiddleOfSelection( 497 {22, 34}, 498 {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20), 499 Token("Token4", 21, 27), Token("Token5", 28, 34)}); 500 EXPECT_EQ(token_index, 4); 501 502 // Some invalid ones. 
503 token_index = internal::CenterTokenFromMiddleOfSelection({7, 27}, {}); 504 EXPECT_EQ(token_index, -1); 505 } 506 507 TEST_F(FeatureProcessorTest, SupportedCodepointsRatio) { 508 FeatureProcessorOptionsT options; 509 options.context_size = 2; 510 options.max_selection_span = 2; 511 options.snap_label_span_boundaries_to_containing_tokens = false; 512 options.feature_version = 2; 513 options.embedding_size = 4; 514 options.bounds_sensitive_features.reset( 515 new FeatureProcessorOptions_::BoundsSensitiveFeaturesT()); 516 options.bounds_sensitive_features->enabled = true; 517 options.bounds_sensitive_features->num_tokens_before = 5; 518 options.bounds_sensitive_features->num_tokens_inside_left = 3; 519 options.bounds_sensitive_features->num_tokens_inside_right = 3; 520 options.bounds_sensitive_features->num_tokens_after = 5; 521 options.bounds_sensitive_features->include_inside_bag = true; 522 options.bounds_sensitive_features->include_inside_length = true; 523 524 options.tokenization_codepoint_config.emplace_back( 525 new TokenizationCodepointRangeT()); 526 auto& config = options.tokenization_codepoint_config.back(); 527 config->start = 32; 528 config->end = 33; 529 config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR; 530 531 { 532 options.supported_codepoint_ranges.emplace_back(new CodepointRangeT()); 533 auto& range = options.supported_codepoint_ranges.back(); 534 range->start = 0; 535 range->end = 128; 536 } 537 538 { 539 options.supported_codepoint_ranges.emplace_back(new CodepointRangeT()); 540 auto& range = options.supported_codepoint_ranges.back(); 541 range->start = 10000; 542 range->end = 10001; 543 } 544 545 { 546 options.supported_codepoint_ranges.emplace_back(new CodepointRangeT()); 547 auto& range = options.supported_codepoint_ranges.back(); 548 range->start = 20000; 549 range->end = 30000; 550 } 551 552 flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); 553 TestingFeatureProcessor feature_processor( 554 
flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), 555 &unilib_); 556 EXPECT_THAT(feature_processor.SupportedCodepointsRatio( 557 {0, 3}, feature_processor.Tokenize("aaa bbb ccc")), 558 FloatEq(1.0)); 559 EXPECT_THAT(feature_processor.SupportedCodepointsRatio( 560 {0, 3}, feature_processor.Tokenize("aaa bbb ")), 561 FloatEq(2.0 / 3)); 562 EXPECT_THAT(feature_processor.SupportedCodepointsRatio( 563 {0, 3}, feature_processor.Tokenize(" ")), 564 FloatEq(0.0)); 565 EXPECT_FALSE( 566 IsCodepointInRanges(-1, feature_processor.supported_codepoint_ranges_)); 567 EXPECT_TRUE( 568 IsCodepointInRanges(0, feature_processor.supported_codepoint_ranges_)); 569 EXPECT_TRUE( 570 IsCodepointInRanges(10, feature_processor.supported_codepoint_ranges_)); 571 EXPECT_TRUE( 572 IsCodepointInRanges(127, feature_processor.supported_codepoint_ranges_)); 573 EXPECT_FALSE( 574 IsCodepointInRanges(128, feature_processor.supported_codepoint_ranges_)); 575 EXPECT_FALSE( 576 IsCodepointInRanges(9999, feature_processor.supported_codepoint_ranges_)); 577 EXPECT_TRUE(IsCodepointInRanges( 578 10000, feature_processor.supported_codepoint_ranges_)); 579 EXPECT_FALSE(IsCodepointInRanges( 580 10001, feature_processor.supported_codepoint_ranges_)); 581 EXPECT_TRUE(IsCodepointInRanges( 582 25000, feature_processor.supported_codepoint_ranges_)); 583 584 const std::vector<Token> tokens = {Token("", 0, 3), Token("", 4, 7), 585 Token("eee", 8, 11)}; 586 587 options.min_supported_codepoint_ratio = 0.0; 588 flatbuffers::DetachedBuffer options2_fb = 589 PackFeatureProcessorOptions(options); 590 TestingFeatureProcessor feature_processor2( 591 flatbuffers::GetRoot<FeatureProcessorOptions>(options2_fb.data()), 592 &unilib_); 593 EXPECT_TRUE(feature_processor2.HasEnoughSupportedCodepoints( 594 tokens, /*token_span=*/{0, 3})); 595 596 options.min_supported_codepoint_ratio = 0.2; 597 flatbuffers::DetachedBuffer options3_fb = 598 PackFeatureProcessorOptions(options); 599 TestingFeatureProcessor 
feature_processor3( 600 flatbuffers::GetRoot<FeatureProcessorOptions>(options3_fb.data()), 601 &unilib_); 602 EXPECT_TRUE(feature_processor3.HasEnoughSupportedCodepoints( 603 tokens, /*token_span=*/{0, 3})); 604 605 options.min_supported_codepoint_ratio = 0.5; 606 flatbuffers::DetachedBuffer options4_fb = 607 PackFeatureProcessorOptions(options); 608 TestingFeatureProcessor feature_processor4( 609 flatbuffers::GetRoot<FeatureProcessorOptions>(options4_fb.data()), 610 &unilib_); 611 EXPECT_FALSE(feature_processor4.HasEnoughSupportedCodepoints( 612 tokens, /*token_span=*/{0, 3})); 613 } 614 615 TEST_F(FeatureProcessorTest, InSpanFeature) { 616 FeatureProcessorOptionsT options; 617 options.context_size = 2; 618 options.max_selection_span = 2; 619 options.snap_label_span_boundaries_to_containing_tokens = false; 620 options.feature_version = 2; 621 options.embedding_size = 4; 622 options.extract_selection_mask_feature = true; 623 624 flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options); 625 TestingFeatureProcessor feature_processor( 626 flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()), 627 &unilib_); 628 629 std::unique_ptr<CachedFeatures> cached_features; 630 631 FakeEmbeddingExecutor embedding_executor; 632 633 const std::vector<Token> tokens = {Token("aaa", 0, 3), Token("bbb", 4, 7), 634 Token("ccc", 8, 11), Token("ddd", 12, 15)}; 635 636 EXPECT_TRUE(feature_processor.ExtractFeatures( 637 tokens, /*token_span=*/{0, 4}, 638 /*selection_span_for_feature=*/{4, 11}, &embedding_executor, 639 /*embedding_cache=*/nullptr, /*feature_vector_size=*/5, 640 &cached_features)); 641 std::vector<float> features; 642 cached_features->AppendClickContextFeaturesForClick(1, &features); 643 ASSERT_EQ(features.size(), 25); 644 EXPECT_THAT(features[4], FloatEq(0.0)); 645 EXPECT_THAT(features[9], FloatEq(0.0)); 646 EXPECT_THAT(features[14], FloatEq(1.0)); 647 EXPECT_THAT(features[19], FloatEq(1.0)); 648 EXPECT_THAT(features[24], FloatEq(0.0)); 649 } 
// Verifies that ExtractFeatures uses embeddings already present in the
// supplied EmbeddingCache and stores newly computed embeddings back into it.
TEST_F(FeatureProcessorTest, EmbeddingCache) {
  FeatureProcessorOptionsT options;
  options.context_size = 2;
  options.max_selection_span = 2;
  options.snap_label_span_boundaries_to_containing_tokens = false;
  options.feature_version = 2;
  options.embedding_size = 4;
  options.bounds_sensitive_features.reset(
      new FeatureProcessorOptions_::BoundsSensitiveFeaturesT());
  options.bounds_sensitive_features->enabled = true;
  options.bounds_sensitive_features->num_tokens_before = 3;
  options.bounds_sensitive_features->num_tokens_inside_left = 2;
  options.bounds_sensitive_features->num_tokens_inside_right = 2;
  options.bounds_sensitive_features->num_tokens_after = 3;

  flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
  TestingFeatureProcessor feature_processor(
      flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
      &unilib_);

  std::unique_ptr<CachedFeatures> cached_features;

  FakeEmbeddingExecutor embedding_executor;

  const std::vector<Token> tokens = {
      Token("aaa", 0, 3),   Token("bbb", 4, 7),   Token("ccc", 8, 11),
      Token("ddd", 12, 15), Token("eee", 16, 19), Token("fff", 20, 23)};

  // We pre-populate the cache with dummy embeddings, to make sure they are
  // used when populating the features vector. The cache is keyed by codepoint
  // span; {kInvalidIndex, kInvalidIndex} is the key checked below for the
  // padding slots.
  const std::vector<float> cached_padding_features = {10.0, -10.0, 10.0, -10.0};
  const std::vector<float> cached_features1 = {1.0, 2.0, 3.0, 4.0};
  const std::vector<float> cached_features2 = {5.0, 6.0, 7.0, 8.0};
  FeatureProcessor::EmbeddingCache embedding_cache = {
      {{kInvalidIndex, kInvalidIndex}, cached_padding_features},
      {{4, 7}, cached_features1},
      {{12, 15}, cached_features2},
  };

  EXPECT_TRUE(feature_processor.ExtractFeatures(
      tokens, /*token_span=*/{0, 6},
      /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
      &embedding_executor, &embedding_cache, /*feature_vector_size=*/4,
      &cached_features));
  std::vector<float> features;
  cached_features->AppendBoundsSensitiveFeaturesForSpan({2, 4}, &features);
  // 3 + 2 + 2 + 3 configured token slots, 4 floats each.
  ASSERT_EQ(features.size(), 40);
  // Check that the dummy embeddings were used.
  EXPECT_THAT(Subvector(features, 0, 4),
              ElementsAreFloat(cached_padding_features));
  EXPECT_THAT(Subvector(features, 8, 12), ElementsAreFloat(cached_features1));
  EXPECT_THAT(Subvector(features, 16, 20), ElementsAreFloat(cached_features2));
  EXPECT_THAT(Subvector(features, 24, 28), ElementsAreFloat(cached_features2));
  EXPECT_THAT(Subvector(features, 36, 40),
              ElementsAreFloat(cached_padding_features));
  // Check that the real embeddings were cached: the 3 pre-populated entries
  // plus one new entry for each of the 4 tokens that were not in the cache.
  EXPECT_EQ(embedding_cache.size(), 7);
  EXPECT_THAT(Subvector(features, 4, 8),
              ElementsAreFloat(embedding_cache.at({0, 3})));
  EXPECT_THAT(Subvector(features, 12, 16),
              ElementsAreFloat(embedding_cache.at({8, 11})));
  EXPECT_THAT(Subvector(features, 20, 24),
              ElementsAreFloat(embedding_cache.at({8, 11})));
  EXPECT_THAT(Subvector(features, 28, 32),
              ElementsAreFloat(embedding_cache.at({16, 19})));
  EXPECT_THAT(Subvector(features, 32, 36),
              ElementsAreFloat(embedding_cache.at({20, 23})));
}

// StripOrPadTokens with a {0, 0} relative click span and context size 2: the
// result always has 5 tokens with the click at index 2, padded with default
// Token() values where the context runs past either end of the input.
TEST_F(FeatureProcessorTest, StripUnusedTokensWithNoRelativeClick) {
  std::vector<Token> tokens_orig{
      Token("0", 0, 0), Token("1", 0, 0), Token("2", 0, 0),  Token("3", 0, 0),
      Token("4", 0, 0), Token("5", 0, 0), Token("6", 0, 0),  Token("7", 0, 0),
      Token("8", 0, 0), Token("9", 0, 0), Token("10", 0, 0), Token("11", 0, 0),
      Token("12", 0, 0)};

  std::vector<Token> tokens;
  int click_index;

  // Try to click first token and see if it gets padded from left.
  tokens = tokens_orig;
  click_index = 0;
  internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index);
  // clang-format off
  EXPECT_EQ(tokens, std::vector<Token>({Token(),
                                        Token(),
                                        Token("0", 0, 0),
                                        Token("1", 0, 0),
                                        Token("2", 0, 0)}));
  // clang-format on
  EXPECT_EQ(click_index, 2);

  // When we click the token at index 2 (enough context on both sides),
  // nothing should get padded.
  tokens = tokens_orig;
  click_index = 2;
  internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index);
  // clang-format off
  EXPECT_EQ(tokens, std::vector<Token>({Token("0", 0, 0),
                                        Token("1", 0, 0),
                                        Token("2", 0, 0),
                                        Token("3", 0, 0),
                                        Token("4", 0, 0)}));
  // clang-format on
  EXPECT_EQ(click_index, 2);

  // When we click the last token tokens should get padded from the right.
  tokens = tokens_orig;
  click_index = 12;
  internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index);
  // clang-format off
  EXPECT_EQ(tokens, std::vector<Token>({Token("10", 0, 0),
                                        Token("11", 0, 0),
                                        Token("12", 0, 0),
                                        Token(),
                                        Token()}));
  // clang-format on
  EXPECT_EQ(click_index, 2);
}

// StripOrPadTokens with a non-zero relative click span: the kept window grows
// by the span's first/second values on the left/right of the click, on top of
// the context size of 2.
TEST_F(FeatureProcessorTest, StripUnusedTokensWithRelativeClick) {
  std::vector<Token> tokens_orig{
      Token("0", 0, 0), Token("1", 0, 0), Token("2", 0, 0),  Token("3", 0, 0),
      Token("4", 0, 0), Token("5", 0, 0), Token("6", 0, 0),  Token("7", 0, 0),
      Token("8", 0, 0), Token("9", 0, 0), Token("10", 0, 0), Token("11", 0, 0),
      Token("12", 0, 0)};

  std::vector<Token> tokens;
  int click_index;

  // Try to click first token and see if it gets padded from left to maximum
  // context_size.
  tokens = tokens_orig;
  click_index = 0;
  internal::StripOrPadTokens({2, 3}, 2, &tokens, &click_index);
  // clang-format off
  EXPECT_EQ(tokens, std::vector<Token>({Token(),
                                        Token(),
                                        Token("0", 0, 0),
                                        Token("1", 0, 0),
                                        Token("2", 0, 0),
                                        Token("3", 0, 0),
                                        Token("4", 0, 0),
                                        Token("5", 0, 0)}));
  // clang-format on
  EXPECT_EQ(click_index, 2);

  // Clicking to the middle with enough context should not produce any padding.
  tokens = tokens_orig;
  click_index = 6;
  internal::StripOrPadTokens({3, 1}, 2, &tokens, &click_index);
  // clang-format off
  EXPECT_EQ(tokens, std::vector<Token>({Token("1", 0, 0),
                                        Token("2", 0, 0),
                                        Token("3", 0, 0),
                                        Token("4", 0, 0),
                                        Token("5", 0, 0),
                                        Token("6", 0, 0),
                                        Token("7", 0, 0),
                                        Token("8", 0, 0),
                                        Token("9", 0, 0)}));
  // clang-format on
  EXPECT_EQ(click_index, 5);

  // Clicking at the end should pad right to maximum context_size.
  tokens = tokens_orig;
  click_index = 11;
  internal::StripOrPadTokens({3, 1}, 2, &tokens, &click_index);
  // clang-format off
  EXPECT_EQ(tokens, std::vector<Token>({Token("6", 0, 0),
                                        Token("7", 0, 0),
                                        Token("8", 0, 0),
                                        Token("9", 0, 0),
                                        Token("10", 0, 0),
                                        Token("11", 0, 0),
                                        Token("12", 0, 0),
                                        Token(),
                                        Token()}));
  // clang-format on
  EXPECT_EQ(click_index, 5);
}

// Configures '.', ',', '[' and ']' as ignored span boundary codepoints and
// checks both the counting primitive and StripBoundaryCodepoints.
//
// NOTE(review): some literals below look shorter than the expectations imply
// (e.g. "Wold" is expected to occupy codepoints [9, 14)); non-ASCII
// characters may have been lost in an earlier encoding step. Confirm against
// the upstream copy.
TEST_F(FeatureProcessorTest, IgnoredSpanBoundaryCodepoints) {
  FeatureProcessorOptionsT options;
  options.ignored_span_boundary_codepoints.push_back('.');
  options.ignored_span_boundary_codepoints.push_back(',');
  options.ignored_span_boundary_codepoints.push_back('[');
  options.ignored_span_boundary_codepoints.push_back(']');

  flatbuffers::DetachedBuffer options_fb = PackFeatureProcessorOptions(options);
  TestingFeatureProcessor feature_processor(
      flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
      &unilib_);

  // No ignored codepoints at either end.
  const std::string text1_utf8 = "";
  const UnicodeText text1 = UTF8ToUnicodeText(text1_utf8, /*do_copy=*/false);
  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
                text1.begin(), text1.end(),
                /*count_from_beginning=*/true),
            0);
  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
                text1.begin(), text1.end(),
                /*count_from_beginning=*/false),
            0);

  // Ignored codepoints only at the beginning.
  const std::string text2_utf8 = ".,abd";
  const UnicodeText text2 = UTF8ToUnicodeText(text2_utf8, /*do_copy=*/false);
  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
                text2.begin(), text2.end(),
                /*count_from_beginning=*/true),
            2);
  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
                text2.begin(), text2.end(),
                /*count_from_beginning=*/false),
            0);

  // Ignored codepoints at both ends.
  const std::string text3_utf8 = ".,abd[]";
  const UnicodeText text3 = UTF8ToUnicodeText(text3_utf8, /*do_copy=*/false);
  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
                text3.begin(), text3.end(),
                /*count_from_beginning=*/true),
            2);
  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
                text3.begin(), text3.end(),
                /*count_from_beginning=*/false),
            2);

  // A single ignored codepoint at each end.
  const std::string text4_utf8 = "[abd]";
  const UnicodeText text4 = UTF8ToUnicodeText(text4_utf8, /*do_copy=*/false);
  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
                text4.begin(), text4.end(),
                /*count_from_beginning=*/true),
            1);
  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
                text4.begin(), text4.end(),
                /*count_from_beginning=*/false),
            1);

  // Empty text.
  const std::string text5_utf8 = "";
  const UnicodeText text5 = UTF8ToUnicodeText(text5_utf8, /*do_copy=*/false);
  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
                text5.begin(), text5.end(),
                /*count_from_beginning=*/true),
            0);
  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
                text5.begin(), text5.end(),
                /*count_from_beginning=*/false),
            0);

  // Empty iterator range positioned at the end of a non-empty text.
  const std::string text6_utf8 = "012345";
  const UnicodeText text6 = UTF8ToUnicodeText(text6_utf8, /*do_copy=*/false);
  UnicodeText::const_iterator text6_begin = text6.begin();
  std::advance(text6_begin, 6);
  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
                text6_begin, text6.end(),
                /*count_from_beginning=*/true),
            0);
  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
                text6_begin, text6.end(),
                /*count_from_beginning=*/false),
            0);

  // Sub-ranges of a text: counting forward from an offset begin iterator, and
  // counting backward from an explicitly advanced end iterator.
  const std::string text7_utf8 = "012345.,";
  const UnicodeText text7 = UTF8ToUnicodeText(text7_utf8, /*do_copy=*/false);
  UnicodeText::const_iterator text7_begin = text7.begin();
  std::advance(text7_begin, 6);
  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
                text7_begin, text7.end(),
                /*count_from_beginning=*/true),
            2);
  UnicodeText::const_iterator text7_end = text7.begin();
  std::advance(text7_end, 8);
  EXPECT_EQ(feature_processor.CountIgnoredSpanBoundaryCodepoints(
                text7.begin(), text7_end,
                /*count_from_beginning=*/false),
            2);

  // Test not stripping.
  EXPECT_EQ(feature_processor.StripBoundaryCodepoints(
                "Hello [[[Wold]] or not?", {0, 24}),
            std::make_pair(0, 24));
  // Test basic stripping.
  EXPECT_EQ(feature_processor.StripBoundaryCodepoints(
                "Hello [[[Wold]] or not?", {6, 16}),
            std::make_pair(9, 14));
  // Test stripping when everything is stripped.
  EXPECT_EQ(
      feature_processor.StripBoundaryCodepoints("Hello [[[]] or not?", {6, 11}),
      std::make_pair(6, 6));
  // Test stripping empty string.
  EXPECT_EQ(feature_processor.StripBoundaryCodepoints("", {0, 0}),
            std::make_pair(0, 0));
}

TEST_F(FeatureProcessorTest, CodepointSpanToTokenSpan) {
  const std::vector<Token> tokens{Token("Hll", 0, 5),
                                  Token("fba@google.com", 6, 23),
                                  Token("hee!", 24, 29)};

  // Spans matching the tokens exactly.
  EXPECT_EQ(TokenSpan(0, 1), CodepointSpanToTokenSpan(tokens, {0, 5}));
  EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {6, 23}));
  EXPECT_EQ(TokenSpan(2, 3), CodepointSpanToTokenSpan(tokens, {24, 29}));
  EXPECT_EQ(TokenSpan(0, 2), CodepointSpanToTokenSpan(tokens, {0, 23}));
  EXPECT_EQ(TokenSpan(1, 3), CodepointSpanToTokenSpan(tokens, {6, 29}));
  EXPECT_EQ(TokenSpan(0, 3), CodepointSpanToTokenSpan(tokens, {0, 29}));

  // Snapping to containing tokens has no effect.
958 EXPECT_EQ(TokenSpan(0, 1), CodepointSpanToTokenSpan(tokens, {0, 5}, true)); 959 EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {6, 23}, true)); 960 EXPECT_EQ(TokenSpan(2, 3), CodepointSpanToTokenSpan(tokens, {24, 29}, true)); 961 EXPECT_EQ(TokenSpan(0, 2), CodepointSpanToTokenSpan(tokens, {0, 23}, true)); 962 EXPECT_EQ(TokenSpan(1, 3), CodepointSpanToTokenSpan(tokens, {6, 29}, true)); 963 EXPECT_EQ(TokenSpan(0, 3), CodepointSpanToTokenSpan(tokens, {0, 29}, true)); 964 965 // Span boundaries inside tokens. 966 EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {1, 28})); 967 EXPECT_EQ(TokenSpan(0, 3), CodepointSpanToTokenSpan(tokens, {1, 28}, true)); 968 969 // Tokens adjacent to the span, but not overlapping. 970 EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {5, 24})); 971 EXPECT_EQ(TokenSpan(1, 2), CodepointSpanToTokenSpan(tokens, {5, 24}, true)); 972 } 973 974 } // namespace 975 } // namespace libtextclassifier3 976