1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "annotator/feature-processor.h" 18 19 #include <iterator> 20 #include <set> 21 #include <vector> 22 23 #include "utils/base/logging.h" 24 #include "utils/strings/utf8.h" 25 #include "utils/utf8/unicodetext.h" 26 27 namespace libtextclassifier3 { 28 29 namespace internal { 30 31 Tokenizer BuildTokenizer(const FeatureProcessorOptions* options, 32 const UniLib* unilib) { 33 std::vector<const TokenizationCodepointRange*> codepoint_config; 34 if (options->tokenization_codepoint_config() != nullptr) { 35 codepoint_config.insert(codepoint_config.end(), 36 options->tokenization_codepoint_config()->begin(), 37 options->tokenization_codepoint_config()->end()); 38 } 39 std::vector<const CodepointRange*> internal_codepoint_config; 40 if (options->internal_tokenizer_codepoint_ranges() != nullptr) { 41 internal_codepoint_config.insert( 42 internal_codepoint_config.end(), 43 options->internal_tokenizer_codepoint_ranges()->begin(), 44 options->internal_tokenizer_codepoint_ranges()->end()); 45 } 46 const bool tokenize_on_script_change = 47 options->tokenization_codepoint_config() != nullptr && 48 options->tokenize_on_script_change(); 49 return Tokenizer(options->tokenization_type(), unilib, codepoint_config, 50 internal_codepoint_config, tokenize_on_script_change, 51 options->icu_preserve_whitespace_tokens()); 52 } 53 54 TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions( 55 const FeatureProcessorOptions* const options) { 56 TokenFeatureExtractorOptions extractor_options; 57 58 extractor_options.num_buckets = options->num_buckets(); 59 if (options->chargram_orders() != nullptr) { 60 for (int order : *options->chargram_orders()) { 61 extractor_options.chargram_orders.push_back(order); 62 } 63 } 64 extractor_options.max_word_length = options->max_word_length(); 65 extractor_options.extract_case_feature = options->extract_case_feature(); 66 extractor_options.unicode_aware_features = options->unicode_aware_features(); 67 extractor_options.extract_selection_mask_feature = 68 options->extract_selection_mask_feature(); 69 if (options->regexp_feature() != nullptr) { 70 for (const auto& regexp_feauture : *options->regexp_feature()) { 71 extractor_options.regexp_features.push_back(regexp_feauture->str()); 72 } 73 } 74 extractor_options.remap_digits = options->remap_digits(); 75 extractor_options.lowercase_tokens = options->lowercase_tokens(); 76 77 if (options->allowed_chargrams() != nullptr) { 78 for (const auto& chargram : *options->allowed_chargrams()) { 79 extractor_options.allowed_chargrams.insert(chargram->str()); 80 } 81 } 82 return extractor_options; 83 } 84 85 void SplitTokensOnSelectionBoundaries(CodepointSpan selection, 86 std::vector<Token>* tokens) { 87 for (auto it = tokens->begin(); it != tokens->end(); ++it) { 88 const UnicodeText token_word = 89 UTF8ToUnicodeText(it->value, /*do_copy=*/false); 90 91 auto last_start = token_word.begin(); 92 int last_start_index = it->start; 93 std::vector<UnicodeText::const_iterator> split_points; 94 95 // Selection start split point. 96 if (selection.first > it->start && selection.first < it->end) { 97 std::advance(last_start, selection.first - last_start_index); 98 split_points.push_back(last_start); 99 last_start_index = selection.first; 100 } 101 102 // Selection end split point. 103 if (selection.second > it->start && selection.second < it->end) { 104 std::advance(last_start, selection.second - last_start_index); 105 split_points.push_back(last_start); 106 } 107 108 if (!split_points.empty()) { 109 // Add a final split for the rest of the token unless it's been all 110 // consumed already. 111 if (split_points.back() != token_word.end()) { 112 split_points.push_back(token_word.end()); 113 } 114 115 std::vector<Token> replacement_tokens; 116 last_start = token_word.begin(); 117 int current_pos = it->start; 118 for (const auto& split_point : split_points) { 119 Token new_token(token_word.UTF8Substring(last_start, split_point), 120 current_pos, 121 current_pos + std::distance(last_start, split_point)); 122 123 last_start = split_point; 124 current_pos = new_token.end; 125 126 replacement_tokens.push_back(new_token); 127 } 128 129 it = tokens->erase(it); 130 it = tokens->insert(it, replacement_tokens.begin(), 131 replacement_tokens.end()); 132 std::advance(it, replacement_tokens.size() - 1); 133 } 134 } 135 } 136 137 } // namespace internal 138 139 void FeatureProcessor::StripTokensFromOtherLines( 140 const std::string& context, CodepointSpan span, 141 std::vector<Token>* tokens) const { 142 const UnicodeText context_unicode = UTF8ToUnicodeText(context, 143 /*do_copy=*/false); 144 StripTokensFromOtherLines(context_unicode, span, tokens); 145 } 146 147 void FeatureProcessor::StripTokensFromOtherLines( 148 const UnicodeText& context_unicode, CodepointSpan span, 149 std::vector<Token>* tokens) const { 150 std::vector<UnicodeTextRange> lines = SplitContext(context_unicode); 151 152 auto span_start = context_unicode.begin(); 153 if (span.first > 0) { 154 std::advance(span_start, span.first); 155 } 156 auto span_end = context_unicode.begin(); 157 if (span.second > 0) { 158 std::advance(span_end, span.second); 159 } 160 for (const UnicodeTextRange& line : lines) { 161 // Find the line that completely contains the span. 162 if (line.first <= span_start && line.second >= span_end) { 163 const CodepointIndex last_line_begin_index = 164 std::distance(context_unicode.begin(), line.first); 165 const CodepointIndex last_line_end_index = 166 last_line_begin_index + std::distance(line.first, line.second); 167 168 for (auto token = tokens->begin(); token != tokens->end();) { 169 if (token->start >= last_line_begin_index && 170 token->end <= last_line_end_index) { 171 ++token; 172 } else { 173 token = tokens->erase(token); 174 } 175 } 176 } 177 } 178 } 179 180 std::string FeatureProcessor::GetDefaultCollection() const { 181 if (options_->default_collection() < 0 || 182 options_->collections() == nullptr || 183 options_->default_collection() >= options_->collections()->size()) { 184 TC3_LOG(ERROR) 185 << "Invalid or missing default collection. Returning empty string."; 186 return ""; 187 } 188 return (*options_->collections())[options_->default_collection()]->str(); 189 } 190 191 std::vector<Token> FeatureProcessor::Tokenize(const std::string& text) const { 192 return tokenizer_.Tokenize(text); 193 } 194 195 std::vector<Token> FeatureProcessor::Tokenize( 196 const UnicodeText& text_unicode) const { 197 return tokenizer_.Tokenize(text_unicode); 198 } 199 200 bool FeatureProcessor::LabelToSpan( 201 const int label, const VectorSpan<Token>& tokens, 202 std::pair<CodepointIndex, CodepointIndex>* span) const { 203 if (tokens.size() != GetNumContextTokens()) { 204 return false; 205 } 206 207 TokenSpan token_span; 208 if (!LabelToTokenSpan(label, &token_span)) { 209 return false; 210 } 211 212 const int result_begin_token_index = token_span.first; 213 const Token& result_begin_token = 214 tokens[options_->context_size() - result_begin_token_index]; 215 const int result_begin_codepoint = result_begin_token.start; 216 const int result_end_token_index = token_span.second; 217 const Token& result_end_token = 218 tokens[options_->context_size() + result_end_token_index]; 219 const int result_end_codepoint = result_end_token.end; 220 221 if (result_begin_codepoint == kInvalidIndex || 222 result_end_codepoint == kInvalidIndex) { 223 *span = CodepointSpan({kInvalidIndex, kInvalidIndex}); 224 } else { 225 const UnicodeText token_begin_unicode = 226 UTF8ToUnicodeText(result_begin_token.value, /*do_copy=*/false); 227 UnicodeText::const_iterator token_begin = token_begin_unicode.begin(); 228 const UnicodeText token_end_unicode = 229 UTF8ToUnicodeText(result_end_token.value, /*do_copy=*/false); 230 UnicodeText::const_iterator token_end = token_end_unicode.end(); 231 232 const int begin_ignored = CountIgnoredSpanBoundaryCodepoints( 233 token_begin, token_begin_unicode.end(), 234 /*count_from_beginning=*/true); 235 const int end_ignored = 236 CountIgnoredSpanBoundaryCodepoints(token_end_unicode.begin(), token_end, 237 /*count_from_beginning=*/false); 238 // In case everything would be stripped, set the span to the original 239 // beginning and zero length. 240 if (begin_ignored == (result_end_codepoint - result_begin_codepoint)) { 241 *span = {result_begin_codepoint, result_begin_codepoint}; 242 } else { 243 *span = CodepointSpan({result_begin_codepoint + begin_ignored, 244 result_end_codepoint - end_ignored}); 245 } 246 } 247 return true; 248 } 249 250 bool FeatureProcessor::LabelToTokenSpan(const int label, 251 TokenSpan* token_span) const { 252 if (label >= 0 && label < label_to_selection_.size()) { 253 *token_span = label_to_selection_[label]; 254 return true; 255 } else { 256 return false; 257 } 258 } 259 260 bool FeatureProcessor::SpanToLabel( 261 const std::pair<CodepointIndex, CodepointIndex>& span, 262 const std::vector<Token>& tokens, int* label) const { 263 if (tokens.size() != GetNumContextTokens()) { 264 return false; 265 } 266 267 const int click_position = 268 options_->context_size(); // Click is always in the middle. 269 const int padding = options_->context_size() - options_->max_selection_span(); 270 271 int span_left = 0; 272 for (int i = click_position - 1; i >= padding; i--) { 273 if (tokens[i].start != kInvalidIndex && tokens[i].end > span.first) { 274 ++span_left; 275 } else { 276 break; 277 } 278 } 279 280 int span_right = 0; 281 for (int i = click_position + 1; i < tokens.size() - padding; ++i) { 282 if (tokens[i].end != kInvalidIndex && tokens[i].start < span.second) { 283 ++span_right; 284 } else { 285 break; 286 } 287 } 288 289 // Check that the spanned tokens cover the whole span. 290 bool tokens_match_span; 291 const CodepointIndex tokens_start = tokens[click_position - span_left].start; 292 const CodepointIndex tokens_end = tokens[click_position + span_right].end; 293 if (options_->snap_label_span_boundaries_to_containing_tokens()) { 294 tokens_match_span = tokens_start <= span.first && tokens_end >= span.second; 295 } else { 296 const UnicodeText token_left_unicode = UTF8ToUnicodeText( 297 tokens[click_position - span_left].value, /*do_copy=*/false); 298 const UnicodeText token_right_unicode = UTF8ToUnicodeText( 299 tokens[click_position + span_right].value, /*do_copy=*/false); 300 301 UnicodeText::const_iterator span_begin = token_left_unicode.begin(); 302 UnicodeText::const_iterator span_end = token_right_unicode.end(); 303 304 const int num_punctuation_start = CountIgnoredSpanBoundaryCodepoints( 305 span_begin, token_left_unicode.end(), /*count_from_beginning=*/true); 306 const int num_punctuation_end = CountIgnoredSpanBoundaryCodepoints( 307 token_right_unicode.begin(), span_end, 308 /*count_from_beginning=*/false); 309 310 tokens_match_span = tokens_start <= span.first && 311 tokens_start + num_punctuation_start >= span.first && 312 tokens_end >= span.second && 313 tokens_end - num_punctuation_end <= span.second; 314 } 315 316 if (tokens_match_span) { 317 *label = TokenSpanToLabel({span_left, span_right}); 318 } else { 319 *label = kInvalidLabel; 320 } 321 322 return true; 323 } 324 325 int FeatureProcessor::TokenSpanToLabel(const TokenSpan& span) const { 326 auto it = selection_to_label_.find(span); 327 if (it != selection_to_label_.end()) { 328 return it->second; 329 } else { 330 return kInvalidLabel; 331 } 332 } 333 334 TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens, 335 CodepointSpan codepoint_span, 336 bool snap_boundaries_to_containing_tokens) { 337 const int codepoint_start = std::get<0>(codepoint_span); 338 const int codepoint_end = std::get<1>(codepoint_span); 339 340 TokenIndex start_token = kInvalidIndex; 341 TokenIndex end_token = kInvalidIndex; 342 for (int i = 0; i < selectable_tokens.size(); ++i) { 343 bool is_token_in_span; 344 if (snap_boundaries_to_containing_tokens) { 345 is_token_in_span = codepoint_start < selectable_tokens[i].end && 346 codepoint_end > selectable_tokens[i].start; 347 } else { 348 is_token_in_span = codepoint_start <= selectable_tokens[i].start && 349 codepoint_end >= selectable_tokens[i].end; 350 } 351 if (is_token_in_span && !selectable_tokens[i].is_padding) { 352 if (start_token == kInvalidIndex) { 353 start_token = i; 354 } 355 end_token = i + 1; 356 } 357 } 358 return {start_token, end_token}; 359 } 360 361 CodepointSpan TokenSpanToCodepointSpan( 362 const std::vector<Token>& selectable_tokens, TokenSpan token_span) { 363 return {selectable_tokens[token_span.first].start, 364 selectable_tokens[token_span.second - 1].end}; 365 } 366 367 namespace { 368 369 // Finds a single token that completely contains the given span. 370 int FindTokenThatContainsSpan(const std::vector<Token>& selectable_tokens, 371 CodepointSpan codepoint_span) { 372 const int codepoint_start = std::get<0>(codepoint_span); 373 const int codepoint_end = std::get<1>(codepoint_span); 374 375 for (int i = 0; i < selectable_tokens.size(); ++i) { 376 if (codepoint_start >= selectable_tokens[i].start && 377 codepoint_end <= selectable_tokens[i].end) { 378 return i; 379 } 380 } 381 return kInvalidIndex; 382 } 383 384 } // namespace 385 386 namespace internal { 387 388 int CenterTokenFromClick(CodepointSpan span, 389 const std::vector<Token>& selectable_tokens) { 390 int range_begin; 391 int range_end; 392 std::tie(range_begin, range_end) = 393 CodepointSpanToTokenSpan(selectable_tokens, span); 394 395 // If no exact match was found, try finding a token that completely contains 396 // the click span. This is useful e.g. when Android builds the selection 397 // using ICU tokenization, and ends up with only a portion of our space- 398 // separated token. E.g. for "(857)" Android would select "857". 399 if (range_begin == kInvalidIndex || range_end == kInvalidIndex) { 400 int token_index = FindTokenThatContainsSpan(selectable_tokens, span); 401 if (token_index != kInvalidIndex) { 402 range_begin = token_index; 403 range_end = token_index + 1; 404 } 405 } 406 407 // We only allow clicks that are exactly 1 selectable token. 408 if (range_end - range_begin == 1) { 409 return range_begin; 410 } else { 411 return kInvalidIndex; 412 } 413 } 414 415 int CenterTokenFromMiddleOfSelection( 416 CodepointSpan span, const std::vector<Token>& selectable_tokens) { 417 int range_begin; 418 int range_end; 419 std::tie(range_begin, range_end) = 420 CodepointSpanToTokenSpan(selectable_tokens, span); 421 422 // Center the clicked token in the selection range. 423 if (range_begin != kInvalidIndex && range_end != kInvalidIndex) { 424 return (range_begin + range_end - 1) / 2; 425 } else { 426 return kInvalidIndex; 427 } 428 } 429 430 } // namespace internal 431 432 int FeatureProcessor::FindCenterToken(CodepointSpan span, 433 const std::vector<Token>& tokens) const { 434 if (options_->center_token_selection_method() == 435 FeatureProcessorOptions_:: 436 CenterTokenSelectionMethod_CENTER_TOKEN_FROM_CLICK) { 437 return internal::CenterTokenFromClick(span, tokens); 438 } else if (options_->center_token_selection_method() == 439 FeatureProcessorOptions_:: 440 CenterTokenSelectionMethod_CENTER_TOKEN_MIDDLE_OF_SELECTION) { 441 return internal::CenterTokenFromMiddleOfSelection(span, tokens); 442 } else if (options_->center_token_selection_method() == 443 FeatureProcessorOptions_:: 444 CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD) { 445 // TODO(zilka): Remove once we have new models on the device. 446 // It uses the fact that sharing model use 447 // split_tokens_on_selection_boundaries and selection not. So depending on 448 // this we select the right way of finding the click location. 449 if (!options_->split_tokens_on_selection_boundaries()) { 450 // SmartSelection model. 451 return internal::CenterTokenFromClick(span, tokens); 452 } else { 453 // SmartSharing model. 454 return internal::CenterTokenFromMiddleOfSelection(span, tokens); 455 } 456 } else { 457 TC3_LOG(ERROR) << "Invalid center token selection method."; 458 return kInvalidIndex; 459 } 460 } 461 462 bool FeatureProcessor::SelectionLabelSpans( 463 const VectorSpan<Token> tokens, 464 std::vector<CodepointSpan>* selection_label_spans) const { 465 for (int i = 0; i < label_to_selection_.size(); ++i) { 466 CodepointSpan span; 467 if (!LabelToSpan(i, tokens, &span)) { 468 TC3_LOG(ERROR) << "Could not convert label to span: " << i; 469 return false; 470 } 471 selection_label_spans->push_back(span); 472 } 473 return true; 474 } 475 476 void FeatureProcessor::PrepareIgnoredSpanBoundaryCodepoints() { 477 if (options_->ignored_span_boundary_codepoints() != nullptr) { 478 for (const int codepoint : *options_->ignored_span_boundary_codepoints()) { 479 ignored_span_boundary_codepoints_.insert(codepoint); 480 } 481 } 482 } 483 484 int FeatureProcessor::CountIgnoredSpanBoundaryCodepoints( 485 const UnicodeText::const_iterator& span_start, 486 const UnicodeText::const_iterator& span_end, 487 bool count_from_beginning) const { 488 if (span_start == span_end) { 489 return 0; 490 } 491 492 UnicodeText::const_iterator it; 493 UnicodeText::const_iterator it_last; 494 if (count_from_beginning) { 495 it = span_start; 496 it_last = span_end; 497 // We can assume that the string is non-zero length because of the check 498 // above, thus the decrement is always valid here. 499 --it_last; 500 } else { 501 it = span_end; 502 it_last = span_start; 503 // We can assume that the string is non-zero length because of the check 504 // above, thus the decrement is always valid here. 505 --it; 506 } 507 508 // Move until we encounter a non-ignored character. 509 int num_ignored = 0; 510 while (ignored_span_boundary_codepoints_.find(*it) != 511 ignored_span_boundary_codepoints_.end()) { 512 ++num_ignored; 513 514 if (it == it_last) { 515 break; 516 } 517 518 if (count_from_beginning) { 519 ++it; 520 } else { 521 --it; 522 } 523 } 524 525 return num_ignored; 526 } 527 528 namespace { 529 530 void FindSubstrings(const UnicodeText& t, const std::set<char32>& codepoints, 531 std::vector<UnicodeTextRange>* ranges) { 532 UnicodeText::const_iterator start = t.begin(); 533 UnicodeText::const_iterator curr = start; 534 UnicodeText::const_iterator end = t.end(); 535 for (; curr != end; ++curr) { 536 if (codepoints.find(*curr) != codepoints.end()) { 537 if (start != curr) { 538 ranges->push_back(std::make_pair(start, curr)); 539 } 540 start = curr; 541 ++start; 542 } 543 } 544 if (start != end) { 545 ranges->push_back(std::make_pair(start, end)); 546 } 547 } 548 549 } // namespace 550 551 std::vector<UnicodeTextRange> FeatureProcessor::SplitContext( 552 const UnicodeText& context_unicode) const { 553 std::vector<UnicodeTextRange> lines; 554 const std::set<char32> codepoints{{'\n', '|'}}; 555 FindSubstrings(context_unicode, codepoints, &lines); 556 return lines; 557 } 558 559 CodepointSpan FeatureProcessor::StripBoundaryCodepoints( 560 const std::string& context, CodepointSpan span) const { 561 const UnicodeText context_unicode = 562 UTF8ToUnicodeText(context, /*do_copy=*/false); 563 return StripBoundaryCodepoints(context_unicode, span); 564 } 565 566 CodepointSpan FeatureProcessor::StripBoundaryCodepoints( 567 const UnicodeText& context_unicode, CodepointSpan span) const { 568 if (context_unicode.empty() || !ValidNonEmptySpan(span)) { 569 return span; 570 } 571 572 UnicodeText::const_iterator span_begin = context_unicode.begin(); 573 std::advance(span_begin, span.first); 574 UnicodeText::const_iterator span_end = context_unicode.begin(); 575 std::advance(span_end, span.second); 576 577 return StripBoundaryCodepoints(span_begin, span_end, span); 578 } 579 580 CodepointSpan FeatureProcessor::StripBoundaryCodepoints( 581 const UnicodeText::const_iterator& span_begin, 582 const UnicodeText::const_iterator& span_end, CodepointSpan span) const { 583 if (!ValidNonEmptySpan(span) || span_begin == span_end) { 584 return span; 585 } 586 587 const int start_offset = CountIgnoredSpanBoundaryCodepoints( 588 span_begin, span_end, /*count_from_beginning=*/true); 589 const int end_offset = CountIgnoredSpanBoundaryCodepoints( 590 span_begin, span_end, /*count_from_beginning=*/false); 591 592 if (span.first + start_offset < span.second - end_offset) { 593 return {span.first + start_offset, span.second - end_offset}; 594 } else { 595 return {span.first, span.first}; 596 } 597 } 598 599 float FeatureProcessor::SupportedCodepointsRatio( 600 const TokenSpan& token_span, const std::vector<Token>& tokens) const { 601 int num_supported = 0; 602 int num_total = 0; 603 for (int i = token_span.first; i < token_span.second; ++i) { 604 const UnicodeText value = 605 UTF8ToUnicodeText(tokens[i].value, /*do_copy=*/false); 606 for (auto codepoint : value) { 607 if (IsCodepointInRanges(codepoint, supported_codepoint_ranges_)) { 608 ++num_supported; 609 } 610 ++num_total; 611 } 612 } 613 return static_cast<float>(num_supported) / static_cast<float>(num_total); 614 } 615 616 const std::string& FeatureProcessor::StripBoundaryCodepoints( 617 const std::string& value, std::string* buffer) const { 618 const UnicodeText value_unicode = UTF8ToUnicodeText(value, /*do_copy=*/false); 619 const CodepointSpan initial_span{0, value_unicode.size_codepoints()}; 620 const CodepointSpan stripped_span = 621 StripBoundaryCodepoints(value_unicode, initial_span); 622 623 if (initial_span != stripped_span) { 624 const UnicodeText stripped_token_value = 625 UnicodeText::Substring(value_unicode, stripped_span.first, 626 stripped_span.second, /*do_copy=*/false); 627 *buffer = stripped_token_value.ToUTF8String(); 628 return *buffer; 629 } 630 return value; 631 } 632 633 int FeatureProcessor::CollectionToLabel(const std::string& collection) const { 634 const auto it = collection_to_label_.find(collection); 635 if (it == collection_to_label_.end()) { 636 return options_->default_collection(); 637 } else { 638 return it->second; 639 } 640 } 641 642 std::string FeatureProcessor::LabelToCollection(int label) const { 643 if (label >= 0 && label < collection_to_label_.size()) { 644 return (*options_->collections())[label]->str(); 645 } else { 646 return GetDefaultCollection(); 647 } 648 } 649 650 void FeatureProcessor::MakeLabelMaps() { 651 if (options_->collections() != nullptr) { 652 for (int i = 0; i < options_->collections()->size(); ++i) { 653 collection_to_label_[(*options_->collections())[i]->str()] = i; 654 } 655 } 656 657 int selection_label_id = 0; 658 for (int l = 0; l < (options_->max_selection_span() + 1); ++l) { 659 for (int r = 0; r < (options_->max_selection_span() + 1); ++r) { 660 if (!options_->selection_reduced_output_space() || 661 r + l <= options_->max_selection_span()) { 662 TokenSpan token_span{l, r}; 663 selection_to_label_[token_span] = selection_label_id; 664 label_to_selection_.push_back(token_span); 665 ++selection_label_id; 666 } 667 } 668 } 669 } 670 671 void FeatureProcessor::RetokenizeAndFindClick(const std::string& context, 672 CodepointSpan input_span, 673 bool only_use_line_with_click, 674 std::vector<Token>* tokens, 675 int* click_pos) const { 676 const UnicodeText context_unicode = 677 UTF8ToUnicodeText(context, /*do_copy=*/false); 678 RetokenizeAndFindClick(context_unicode, input_span, only_use_line_with_click, 679 tokens, click_pos); 680 } 681 682 void FeatureProcessor::RetokenizeAndFindClick( 683 const UnicodeText& context_unicode, CodepointSpan input_span, 684 bool only_use_line_with_click, std::vector<Token>* tokens, 685 int* click_pos) const { 686 TC3_CHECK(tokens != nullptr); 687 688 if (options_->split_tokens_on_selection_boundaries()) { 689 internal::SplitTokensOnSelectionBoundaries(input_span, tokens); 690 } 691 692 if (only_use_line_with_click) { 693 StripTokensFromOtherLines(context_unicode, input_span, tokens); 694 } 695 696 int local_click_pos; 697 if (click_pos == nullptr) { 698 click_pos = &local_click_pos; 699 } 700 *click_pos = FindCenterToken(input_span, *tokens); 701 if (*click_pos == kInvalidIndex) { 702 // If the default click method failed, let's try to do sub-token matching 703 // before we fail. 704 *click_pos = internal::CenterTokenFromClick(input_span, *tokens); 705 } 706 } 707 708 namespace internal { 709 710 void StripOrPadTokens(TokenSpan relative_click_span, int context_size, 711 std::vector<Token>* tokens, int* click_pos) { 712 int right_context_needed = relative_click_span.second + context_size; 713 if (*click_pos + right_context_needed + 1 >= tokens->size()) { 714 // Pad max the context size. 715 const int num_pad_tokens = std::min( 716 context_size, static_cast<int>(*click_pos + right_context_needed + 1 - 717 tokens->size())); 718 std::vector<Token> pad_tokens(num_pad_tokens); 719 tokens->insert(tokens->end(), pad_tokens.begin(), pad_tokens.end()); 720 } else if (*click_pos + right_context_needed + 1 < tokens->size() - 1) { 721 // Strip unused tokens. 722 auto it = tokens->begin(); 723 std::advance(it, *click_pos + right_context_needed + 1); 724 tokens->erase(it, tokens->end()); 725 } 726 727 int left_context_needed = relative_click_span.first + context_size; 728 if (*click_pos < left_context_needed) { 729 // Pad max the context size. 730 const int num_pad_tokens = 731 std::min(context_size, left_context_needed - *click_pos); 732 std::vector<Token> pad_tokens(num_pad_tokens); 733 tokens->insert(tokens->begin(), pad_tokens.begin(), pad_tokens.end()); 734 *click_pos += num_pad_tokens; 735 } else if (*click_pos > left_context_needed) { 736 // Strip unused tokens. 737 auto it = tokens->begin(); 738 std::advance(it, *click_pos - left_context_needed); 739 *click_pos -= it - tokens->begin(); 740 tokens->erase(tokens->begin(), it); 741 } 742 } 743 744 } // namespace internal 745 746 bool FeatureProcessor::HasEnoughSupportedCodepoints( 747 const std::vector<Token>& tokens, TokenSpan token_span) const { 748 if (options_->min_supported_codepoint_ratio() > 0) { 749 const float supported_codepoint_ratio = 750 SupportedCodepointsRatio(token_span, tokens); 751 if (supported_codepoint_ratio < options_->min_supported_codepoint_ratio()) { 752 TC3_VLOG(1) << "Not enough supported codepoints in the context: " 753 << supported_codepoint_ratio; 754 return false; 755 } 756 } 757 return true; 758 } 759 760 bool FeatureProcessor::ExtractFeatures( 761 const std::vector<Token>& tokens, TokenSpan token_span, 762 CodepointSpan selection_span_for_feature, 763 const EmbeddingExecutor* embedding_executor, 764 EmbeddingCache* embedding_cache, int feature_vector_size, 765 std::unique_ptr<CachedFeatures>* cached_features) const { 766 std::unique_ptr<std::vector<float>> features(new std::vector<float>()); 767 features->reserve(feature_vector_size * TokenSpanSize(token_span)); 768 for (int i = token_span.first; i < token_span.second; ++i) { 769 if (!AppendTokenFeaturesWithCache(tokens[i], selection_span_for_feature, 770 embedding_executor, embedding_cache, 771 features.get())) { 772 TC3_LOG(ERROR) << "Could not get token features."; 773 return false; 774 } 775 } 776 777 std::unique_ptr<std::vector<float>> padding_features( 778 new std::vector<float>()); 779 padding_features->reserve(feature_vector_size); 780 if (!AppendTokenFeaturesWithCache(Token(), selection_span_for_feature, 781 embedding_executor, embedding_cache, 782 padding_features.get())) { 783 TC3_LOG(ERROR) << "Count not get padding token features."; 784 return false; 785 } 786 787 *cached_features = CachedFeatures::Create(token_span, std::move(features), 788 std::move(padding_features), 789 options_, feature_vector_size); 790 if (!*cached_features) { 791 TC3_LOG(ERROR) << "Cound not create cached features."; 792 return false; 793 } 794 795 return true; 796 } 797 798 bool FeatureProcessor::AppendTokenFeaturesWithCache( 799 const Token& token, CodepointSpan selection_span_for_feature, 800 const EmbeddingExecutor* embedding_executor, 801 EmbeddingCache* embedding_cache, 802 std::vector<float>* output_features) const { 803 // Look for the embedded features for the token in the cache, if there is one. 804 if (embedding_cache) { 805 const auto it = embedding_cache->find({token.start, token.end}); 806 if (it != embedding_cache->end()) { 807 // The embedded features were found in the cache, extract only the dense 808 // features. 809 std::vector<float> dense_features; 810 if (!feature_extractor_.Extract( 811 token, token.IsContainedInSpan(selection_span_for_feature), 812 /*sparse_features=*/nullptr, &dense_features)) { 813 TC3_LOG(ERROR) << "Could not extract token's dense features."; 814 return false; 815 } 816 817 // Append both embedded and dense features to the output and return. 818 output_features->insert(output_features->end(), it->second.begin(), 819 it->second.end()); 820 output_features->insert(output_features->end(), dense_features.begin(), 821 dense_features.end()); 822 return true; 823 } 824 } 825 826 // Extract the sparse and dense features. 827 std::vector<int> sparse_features; 828 std::vector<float> dense_features; 829 if (!feature_extractor_.Extract( 830 token, token.IsContainedInSpan(selection_span_for_feature), 831 &sparse_features, &dense_features)) { 832 TC3_LOG(ERROR) << "Could not extract token's features."; 833 return false; 834 } 835 836 // Embed the sparse features, appending them directly to the output. 837 const int embedding_size = GetOptions()->embedding_size(); 838 output_features->resize(output_features->size() + embedding_size); 839 float* output_features_end = 840 output_features->data() + output_features->size(); 841 if (!embedding_executor->AddEmbedding( 842 TensorView<int>(sparse_features.data(), 843 {static_cast<int>(sparse_features.size())}), 844 /*dest=*/output_features_end - embedding_size, 845 /*dest_size=*/embedding_size)) { 846 TC3_LOG(ERROR) << "Cound not embed token's sparse features."; 847 return false; 848 } 849 850 // If there is a cache, the embedded features for the token were not in it, 851 // so insert them. 852 if (embedding_cache) { 853 (*embedding_cache)[{token.start, token.end}] = std::vector<float>( 854 output_features_end - embedding_size, output_features_end); 855 } 856 857 // Append the dense features to the output. 858 output_features->insert(output_features->end(), dense_features.begin(), 859 dense_features.end()); 860 return true; 861 } 862 863 } // namespace libtextclassifier3 864