1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "text-classifier.h" 18 19 #include <algorithm> 20 #include <cctype> 21 #include <cmath> 22 #include <iterator> 23 #include <numeric> 24 25 #include "util/base/logging.h" 26 #include "util/math/softmax.h" 27 #include "util/utf8/unicodetext.h" 28 29 namespace libtextclassifier2 { 30 const std::string& TextClassifier::kOtherCollection = 31 *[]() { return new std::string("other"); }(); 32 const std::string& TextClassifier::kPhoneCollection = 33 *[]() { return new std::string("phone"); }(); 34 const std::string& TextClassifier::kAddressCollection = 35 *[]() { return new std::string("address"); }(); 36 const std::string& TextClassifier::kDateCollection = 37 *[]() { return new std::string("date"); }(); 38 39 namespace { 40 const Model* LoadAndVerifyModel(const void* addr, int size) { 41 const Model* model = GetModel(addr); 42 43 flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size); 44 if (model->Verify(verifier)) { 45 return model; 46 } else { 47 return nullptr; 48 } 49 } 50 } // namespace 51 52 tflite::Interpreter* InterpreterManager::SelectionInterpreter() { 53 if (!selection_interpreter_) { 54 TC_CHECK(selection_executor_); 55 selection_interpreter_ = selection_executor_->CreateInterpreter(); 56 if (!selection_interpreter_) { 57 TC_LOG(ERROR) << "Could not build TFLite interpreter."; 58 } 59 } 60 return 
selection_interpreter_.get(); 61 } 62 63 tflite::Interpreter* InterpreterManager::ClassificationInterpreter() { 64 if (!classification_interpreter_) { 65 TC_CHECK(classification_executor_); 66 classification_interpreter_ = classification_executor_->CreateInterpreter(); 67 if (!classification_interpreter_) { 68 TC_LOG(ERROR) << "Could not build TFLite interpreter."; 69 } 70 } 71 return classification_interpreter_.get(); 72 } 73 74 std::unique_ptr<TextClassifier> TextClassifier::FromUnownedBuffer( 75 const char* buffer, int size, const UniLib* unilib) { 76 const Model* model = LoadAndVerifyModel(buffer, size); 77 if (model == nullptr) { 78 return nullptr; 79 } 80 81 auto classifier = 82 std::unique_ptr<TextClassifier>(new TextClassifier(model, unilib)); 83 if (!classifier->IsInitialized()) { 84 return nullptr; 85 } 86 87 return classifier; 88 } 89 90 std::unique_ptr<TextClassifier> TextClassifier::FromScopedMmap( 91 std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib) { 92 if (!(*mmap)->handle().ok()) { 93 TC_VLOG(1) << "Mmap failed."; 94 return nullptr; 95 } 96 97 const Model* model = LoadAndVerifyModel((*mmap)->handle().start(), 98 (*mmap)->handle().num_bytes()); 99 if (!model) { 100 TC_LOG(ERROR) << "Model verification failed."; 101 return nullptr; 102 } 103 104 auto classifier = 105 std::unique_ptr<TextClassifier>(new TextClassifier(mmap, model, unilib)); 106 if (!classifier->IsInitialized()) { 107 return nullptr; 108 } 109 110 return classifier; 111 } 112 113 std::unique_ptr<TextClassifier> TextClassifier::FromFileDescriptor( 114 int fd, int offset, int size, const UniLib* unilib) { 115 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size)); 116 return FromScopedMmap(&mmap, unilib); 117 } 118 119 std::unique_ptr<TextClassifier> TextClassifier::FromFileDescriptor( 120 int fd, const UniLib* unilib) { 121 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd)); 122 return FromScopedMmap(&mmap, unilib); 123 } 124 125 std::unique_ptr<TextClassifier> 
TextClassifier::FromPath( 126 const std::string& path, const UniLib* unilib) { 127 std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path)); 128 return FromScopedMmap(&mmap, unilib); 129 } 130 131 void TextClassifier::ValidateAndInitialize() { 132 initialized_ = false; 133 134 if (model_ == nullptr) { 135 TC_LOG(ERROR) << "No model specified."; 136 return; 137 } 138 139 const bool model_enabled_for_annotation = 140 (model_->triggering_options() != nullptr && 141 (model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)); 142 const bool model_enabled_for_classification = 143 (model_->triggering_options() != nullptr && 144 (model_->triggering_options()->enabled_modes() & 145 ModeFlag_CLASSIFICATION)); 146 const bool model_enabled_for_selection = 147 (model_->triggering_options() != nullptr && 148 (model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)); 149 150 // Annotation requires the selection model. 151 if (model_enabled_for_annotation || model_enabled_for_selection) { 152 if (!model_->selection_options()) { 153 TC_LOG(ERROR) << "No selection options."; 154 return; 155 } 156 if (!model_->selection_feature_options()) { 157 TC_LOG(ERROR) << "No selection feature options."; 158 return; 159 } 160 if (!model_->selection_feature_options()->bounds_sensitive_features()) { 161 TC_LOG(ERROR) << "No selection bounds sensitive feature options."; 162 return; 163 } 164 if (!model_->selection_model()) { 165 TC_LOG(ERROR) << "No selection model."; 166 return; 167 } 168 selection_executor_ = ModelExecutor::Instance(model_->selection_model()); 169 if (!selection_executor_) { 170 TC_LOG(ERROR) << "Could not initialize selection executor."; 171 return; 172 } 173 selection_feature_processor_.reset( 174 new FeatureProcessor(model_->selection_feature_options(), unilib_)); 175 } 176 177 // Annotation requires the classification model for conflict resolution and 178 // scoring. 179 // Selection requires the classification model for conflict resolution. 
180 if (model_enabled_for_annotation || model_enabled_for_classification || 181 model_enabled_for_selection) { 182 if (!model_->classification_options()) { 183 TC_LOG(ERROR) << "No classification options."; 184 return; 185 } 186 187 if (!model_->classification_feature_options()) { 188 TC_LOG(ERROR) << "No classification feature options."; 189 return; 190 } 191 192 if (!model_->classification_feature_options() 193 ->bounds_sensitive_features()) { 194 TC_LOG(ERROR) << "No classification bounds sensitive feature options."; 195 return; 196 } 197 if (!model_->classification_model()) { 198 TC_LOG(ERROR) << "No clf model."; 199 return; 200 } 201 202 classification_executor_ = 203 ModelExecutor::Instance(model_->classification_model()); 204 if (!classification_executor_) { 205 TC_LOG(ERROR) << "Could not initialize classification executor."; 206 return; 207 } 208 209 classification_feature_processor_.reset(new FeatureProcessor( 210 model_->classification_feature_options(), unilib_)); 211 } 212 213 // The embeddings need to be specified if the model is to be used for 214 // classification or selection. 215 if (model_enabled_for_annotation || model_enabled_for_classification || 216 model_enabled_for_selection) { 217 if (!model_->embedding_model()) { 218 TC_LOG(ERROR) << "No embedding model."; 219 return; 220 } 221 222 // Check that the embedding size of the selection and classification model 223 // matches, as they are using the same embeddings. 
224 if (model_enabled_for_selection && 225 (model_->selection_feature_options()->embedding_size() != 226 model_->classification_feature_options()->embedding_size() || 227 model_->selection_feature_options()->embedding_quantization_bits() != 228 model_->classification_feature_options() 229 ->embedding_quantization_bits())) { 230 TC_LOG(ERROR) << "Mismatching embedding size/quantization."; 231 return; 232 } 233 234 embedding_executor_ = TFLiteEmbeddingExecutor::Instance( 235 model_->embedding_model(), 236 model_->classification_feature_options()->embedding_size(), 237 model_->classification_feature_options() 238 ->embedding_quantization_bits()); 239 if (!embedding_executor_) { 240 TC_LOG(ERROR) << "Could not initialize embedding executor."; 241 return; 242 } 243 } 244 245 std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance(); 246 if (model_->regex_model()) { 247 if (!InitializeRegexModel(decompressor.get())) { 248 TC_LOG(ERROR) << "Could not initialize regex model."; 249 return; 250 } 251 } 252 253 if (model_->datetime_model()) { 254 datetime_parser_ = DatetimeParser::Instance(model_->datetime_model(), 255 *unilib_, decompressor.get()); 256 if (!datetime_parser_) { 257 TC_LOG(ERROR) << "Could not initialize datetime parser."; 258 return; 259 } 260 } 261 262 if (model_->output_options()) { 263 if (model_->output_options()->filtered_collections_annotation()) { 264 for (const auto collection : 265 *model_->output_options()->filtered_collections_annotation()) { 266 filtered_collections_annotation_.insert(collection->str()); 267 } 268 } 269 if (model_->output_options()->filtered_collections_classification()) { 270 for (const auto collection : 271 *model_->output_options()->filtered_collections_classification()) { 272 filtered_collections_classification_.insert(collection->str()); 273 } 274 } 275 if (model_->output_options()->filtered_collections_selection()) { 276 for (const auto collection : 277 
*model_->output_options()->filtered_collections_selection()) { 278 filtered_collections_selection_.insert(collection->str()); 279 } 280 } 281 } 282 283 initialized_ = true; 284 } 285 286 bool TextClassifier::InitializeRegexModel(ZlibDecompressor* decompressor) { 287 if (!model_->regex_model()->patterns()) { 288 return true; 289 } 290 291 // Initialize pattern recognizers. 292 int regex_pattern_id = 0; 293 for (const auto& regex_pattern : *model_->regex_model()->patterns()) { 294 std::unique_ptr<UniLib::RegexPattern> compiled_pattern = 295 UncompressMakeRegexPattern(*unilib_, regex_pattern->pattern(), 296 regex_pattern->compressed_pattern(), 297 decompressor); 298 if (!compiled_pattern) { 299 TC_LOG(INFO) << "Failed to load regex pattern"; 300 return false; 301 } 302 303 if (regex_pattern->enabled_modes() & ModeFlag_ANNOTATION) { 304 annotation_regex_patterns_.push_back(regex_pattern_id); 305 } 306 if (regex_pattern->enabled_modes() & ModeFlag_CLASSIFICATION) { 307 classification_regex_patterns_.push_back(regex_pattern_id); 308 } 309 if (regex_pattern->enabled_modes() & ModeFlag_SELECTION) { 310 selection_regex_patterns_.push_back(regex_pattern_id); 311 } 312 regex_patterns_.push_back({regex_pattern->collection_name()->str(), 313 regex_pattern->target_classification_score(), 314 regex_pattern->priority_score(), 315 std::move(compiled_pattern)}); 316 if (regex_pattern->use_approximate_matching()) { 317 regex_approximate_match_pattern_ids_.insert(regex_pattern_id); 318 } 319 ++regex_pattern_id; 320 } 321 322 return true; 323 } 324 325 namespace { 326 327 int CountDigits(const std::string& str, CodepointSpan selection_indices) { 328 int count = 0; 329 int i = 0; 330 const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false); 331 for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) { 332 if (i >= selection_indices.first && i < selection_indices.second && 333 isdigit(*it)) { 334 ++count; 335 } 336 } 337 return count; 338 } 339 340 
// Returns the UTF8 substring of 'context' covered by the codepoint range
// 'selection_indices'. The indices are assumed to be valid for 'context';
// callers in this file validate them beforehand.
std::string ExtractSelection(const std::string& context,
                             CodepointSpan selection_indices) {
  const UnicodeText context_unicode =
      UTF8ToUnicodeText(context, /*do_copy=*/false);
  auto selection_begin = context_unicode.begin();
  std::advance(selection_begin, selection_indices.first);
  auto selection_end = context_unicode.begin();
  std::advance(selection_end, selection_indices.second);
  return UnicodeText::UTF8Substring(selection_begin, selection_end);
}
}  // namespace

namespace internal {
// Helper function, which if the initial 'span' contains only white-spaces,
// moves the selection to a single-codepoint selection on a left or right side
// of this space. If the span contains any non-whitespace, or no
// non-whitespace codepoint exists to the left, the span is returned
// unchanged.
CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span,
                                            const UnicodeText& context_unicode,
                                            const UniLib& unilib) {
  TC_CHECK(ValidNonEmptySpan(span));

  UnicodeText::const_iterator it;

  // Check that the current selection is all whitespaces.
  it = context_unicode.begin();
  std::advance(it, span.first);
  for (int i = 0; i < (span.second - span.first); ++i, ++it) {
    if (!unilib.IsWhitespace(*it)) {
      return span;
    }
  }

  CodepointSpan result;

  // Try moving left: walk from the start of the span towards the beginning of
  // the text until a non-whitespace codepoint is found, then select exactly
  // that single codepoint.
  result = span;
  it = context_unicode.begin();
  std::advance(it, span.first);
  while (it != context_unicode.begin() && unilib.IsWhitespace(*it)) {
    --result.first;
    --it;
  }
  result.second = result.first + 1;
  if (!unilib.IsWhitespace(*it)) {
    return result;
  }

  // If moving left didn't find a non-whitespace character, just return the
  // original span.
  return span;
}
}  // namespace internal

// Returns true if 'span' has a classification whose top collection is in the
// annotation filter set.
bool TextClassifier::FilteredForAnnotation(const AnnotatedSpan& span) const {
  return !span.classification.empty() &&
         filtered_collections_annotation_.find(
             span.classification[0].collection) !=
             filtered_collections_annotation_.end();
}

// Returns true if the classification's collection is in the classification
// filter set.
bool TextClassifier::FilteredForClassification(
    const ClassificationResult& classification) const {
  return filtered_collections_classification_.find(classification.collection) !=
         filtered_collections_classification_.end();
}

// Returns true if 'span' has a classification whose top collection is in the
// selection filter set.
bool TextClassifier::FilteredForSelection(const AnnotatedSpan& span) const {
  return !span.classification.empty() &&
         filtered_collections_selection_.find(
             span.classification[0].collection) !=
             filtered_collections_selection_.end();
}

// Suggests a selection span around 'click_indices' in 'context'. Candidates
// come from the ML selection model, the regex patterns and the datetime
// parser; overlapping candidates are resolved via ResolveConflicts and the
// first surviving candidate overlapping the click is returned. On any error
// or invalid input the original click indices are returned unchanged.
CodepointSpan TextClassifier::SuggestSelection(
    const std::string& context, CodepointSpan click_indices,
    const SelectionOptions& options) const {
  CodepointSpan original_click_indices = click_indices;
  if (!initialized_) {
    TC_LOG(ERROR) << "Not initialized";
    return original_click_indices;
  }
  if (!(model_->enabled_modes() & ModeFlag_SELECTION)) {
    return original_click_indices;
  }

  const UnicodeText context_unicode = UTF8ToUnicodeText(context,
                                                        /*do_copy=*/false);

  // Bail out on malformed UTF8 rather than risk iterating invalid text.
  if (!context_unicode.is_valid()) {
    return original_click_indices;
  }

  const int context_codepoint_size = context_unicode.size_codepoints();

  if (click_indices.first < 0 || click_indices.second < 0 ||
      click_indices.first >= context_codepoint_size ||
      click_indices.second > context_codepoint_size ||
      click_indices.first >= click_indices.second) {
    TC_VLOG(1) << "Trying to run SuggestSelection with invalid indices: "
               << click_indices.first << " " << click_indices.second;
    return original_click_indices;
  }

  if (model_->snap_whitespace_selections()) {
    // We want to expand a purely white-space selection to a multi-selection it
    // would've been part of. But with this feature disabled we would do a no-
    // op, because no token is found. Therefore, we need to modify the
    // 'click_indices' a bit to include a part of the token, so that the click-
    // finding logic finds the clicked token correctly. This modification is
    // done by the following function. Note, that it's enough to check the left
    // side of the current selection, because if the white-space is a part of a
    // multi-selection, neccessarily both tokens - on the left and the right
    // sides need to be selected. Thus snapping only to the left is sufficient
    // (there's a check at the bottom that makes sure that if we snap to the
    // left token but the result does not contain the initial white-space,
    // returns the original indices).
    click_indices = internal::SnapLeftIfWhitespaceSelection(
        click_indices, context_unicode, *unilib_);
  }

  // Gather candidates from all three sources; any failure aborts with the
  // original indices.
  std::vector<AnnotatedSpan> candidates;
  InterpreterManager interpreter_manager(selection_executor_.get(),
                                         classification_executor_.get());
  std::vector<Token> tokens;
  if (!ModelSuggestSelection(context_unicode, click_indices,
                             &interpreter_manager, &tokens, &candidates)) {
    TC_LOG(ERROR) << "Model suggest selection failed.";
    return original_click_indices;
  }
  if (!RegexChunk(context_unicode, selection_regex_patterns_, &candidates)) {
    TC_LOG(ERROR) << "Regex suggest selection failed.";
    return original_click_indices;
  }
  if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
                     /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
                     options.locales, ModeFlag_SELECTION, &candidates)) {
    TC_LOG(ERROR) << "Datetime suggest selection failed.";
    return original_click_indices;
  }

  // Sort candidates according to their position in the input, so that the next
  // code can assume that any connected component of overlapping spans forms a
  // contiguous block.
  std::sort(candidates.begin(), candidates.end(),
            [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
              return a.span.first < b.span.first;
            });

  std::vector<int> candidate_indices;
  if (!ResolveConflicts(candidates, context, tokens, &interpreter_manager,
                        &candidate_indices)) {
    TC_LOG(ERROR) << "Couldn't resolve conflicts.";
    return original_click_indices;
  }

  // Return the first surviving candidate that overlaps both the (possibly
  // snapped) click and the original click.
  for (const int i : candidate_indices) {
    if (SpansOverlap(candidates[i].span, click_indices) &&
        SpansOverlap(candidates[i].span, original_click_indices)) {
      // Run model classification if not present but requested and there's a
      // classification collection filter specified.
      if (candidates[i].classification.empty() &&
          model_->selection_options()->always_classify_suggested_selection() &&
          !filtered_collections_selection_.empty()) {
        if (!ModelClassifyText(
                context, candidates[i].span, &interpreter_manager,
                /*embedding_cache=*/nullptr, &candidates[i].classification)) {
          return original_click_indices;
        }
      }

      // Ignore if span classification is filtered.
      if (FilteredForSelection(candidates[i])) {
        return original_click_indices;
      }

      return candidates[i].span;
    }
  }

  return original_click_indices;
}

namespace {
// Helper function that returns the index of the first candidate that
// transitively does not overlap with the candidate on 'start_index'. If the end
// of 'candidates' is reached, it returns the index that points right behind the
// array.
527 int FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan>& candidates, 528 int start_index) { 529 int first_non_overlapping = start_index + 1; 530 CodepointSpan conflicting_span = candidates[start_index].span; 531 while ( 532 first_non_overlapping < candidates.size() && 533 SpansOverlap(conflicting_span, candidates[first_non_overlapping].span)) { 534 // Grow the span to include the current one. 535 conflicting_span.second = std::max( 536 conflicting_span.second, candidates[first_non_overlapping].span.second); 537 538 ++first_non_overlapping; 539 } 540 return first_non_overlapping; 541 } 542 } // namespace 543 544 bool TextClassifier::ResolveConflicts( 545 const std::vector<AnnotatedSpan>& candidates, const std::string& context, 546 const std::vector<Token>& cached_tokens, 547 InterpreterManager* interpreter_manager, std::vector<int>* result) const { 548 result->clear(); 549 result->reserve(candidates.size()); 550 for (int i = 0; i < candidates.size();) { 551 int first_non_overlapping = 552 FirstNonOverlappingSpanIndex(candidates, /*start_index=*/i); 553 554 const bool conflict_found = first_non_overlapping != (i + 1); 555 if (conflict_found) { 556 std::vector<int> candidate_indices; 557 if (!ResolveConflict(context, cached_tokens, candidates, i, 558 first_non_overlapping, interpreter_manager, 559 &candidate_indices)) { 560 return false; 561 } 562 result->insert(result->end(), candidate_indices.begin(), 563 candidate_indices.end()); 564 } else { 565 result->push_back(i); 566 } 567 568 // Skip over the whole conflicting group/go to next candidate. 
569 i = first_non_overlapping; 570 } 571 return true; 572 } 573 574 namespace { 575 inline bool ClassifiedAsOther( 576 const std::vector<ClassificationResult>& classification) { 577 return !classification.empty() && 578 classification[0].collection == TextClassifier::kOtherCollection; 579 } 580 581 float GetPriorityScore( 582 const std::vector<ClassificationResult>& classification) { 583 if (!ClassifiedAsOther(classification)) { 584 return classification[0].priority_score; 585 } else { 586 return -1.0; 587 } 588 } 589 } // namespace 590 591 bool TextClassifier::ResolveConflict( 592 const std::string& context, const std::vector<Token>& cached_tokens, 593 const std::vector<AnnotatedSpan>& candidates, int start_index, 594 int end_index, InterpreterManager* interpreter_manager, 595 std::vector<int>* chosen_indices) const { 596 std::vector<int> conflicting_indices; 597 std::unordered_map<int, float> scores; 598 for (int i = start_index; i < end_index; ++i) { 599 conflicting_indices.push_back(i); 600 if (!candidates[i].classification.empty()) { 601 scores[i] = GetPriorityScore(candidates[i].classification); 602 continue; 603 } 604 605 // OPTIMIZATION: So that we don't have to classify all the ML model 606 // spans apriori, we wait until we get here, when they conflict with 607 // something and we need the actual classification scores. 
So if the 608 // candidate conflicts and comes from the model, we need to run a 609 // classification to determine its priority: 610 std::vector<ClassificationResult> classification; 611 if (!ModelClassifyText(context, cached_tokens, candidates[i].span, 612 interpreter_manager, 613 /*embedding_cache=*/nullptr, &classification)) { 614 return false; 615 } 616 617 if (!classification.empty()) { 618 scores[i] = GetPriorityScore(classification); 619 } 620 } 621 622 std::sort(conflicting_indices.begin(), conflicting_indices.end(), 623 [&scores](int i, int j) { return scores[i] > scores[j]; }); 624 625 // Keeps the candidates sorted by their position in the text (their left span 626 // index) for fast retrieval down. 627 std::set<int, std::function<bool(int, int)>> chosen_indices_set( 628 [&candidates](int a, int b) { 629 return candidates[a].span.first < candidates[b].span.first; 630 }); 631 632 // Greedily place the candidates if they don't conflict with the already 633 // placed ones. 634 for (int i = 0; i < conflicting_indices.size(); ++i) { 635 const int considered_candidate = conflicting_indices[i]; 636 if (!DoesCandidateConflict(considered_candidate, candidates, 637 chosen_indices_set)) { 638 chosen_indices_set.insert(considered_candidate); 639 } 640 } 641 642 *chosen_indices = 643 std::vector<int>(chosen_indices_set.begin(), chosen_indices_set.end()); 644 645 return true; 646 } 647 648 bool TextClassifier::ModelSuggestSelection( 649 const UnicodeText& context_unicode, CodepointSpan click_indices, 650 InterpreterManager* interpreter_manager, std::vector<Token>* tokens, 651 std::vector<AnnotatedSpan>* result) const { 652 if (model_->triggering_options() == nullptr || 653 !(model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)) { 654 return true; 655 } 656 657 int click_pos; 658 *tokens = selection_feature_processor_->Tokenize(context_unicode); 659 selection_feature_processor_->RetokenizeAndFindClick( 660 context_unicode, click_indices, 661 
selection_feature_processor_->GetOptions()->only_use_line_with_click(), 662 tokens, &click_pos); 663 if (click_pos == kInvalidIndex) { 664 TC_VLOG(1) << "Could not calculate the click position."; 665 return false; 666 } 667 668 const int symmetry_context_size = 669 model_->selection_options()->symmetry_context_size(); 670 const FeatureProcessorOptions_::BoundsSensitiveFeatures* 671 bounds_sensitive_features = selection_feature_processor_->GetOptions() 672 ->bounds_sensitive_features(); 673 674 // The symmetry context span is the clicked token with symmetry_context_size 675 // tokens on either side. 676 const TokenSpan symmetry_context_span = IntersectTokenSpans( 677 ExpandTokenSpan(SingleTokenSpan(click_pos), 678 /*num_tokens_left=*/symmetry_context_size, 679 /*num_tokens_right=*/symmetry_context_size), 680 {0, tokens->size()}); 681 682 // Compute the extraction span based on the model type. 683 TokenSpan extraction_span; 684 if (bounds_sensitive_features && bounds_sensitive_features->enabled()) { 685 // The extraction span is the symmetry context span expanded to include 686 // max_selection_span tokens on either side, which is how far a selection 687 // can stretch from the click, plus a relevant number of tokens outside of 688 // the bounds of the selection. 689 const int max_selection_span = 690 selection_feature_processor_->GetOptions()->max_selection_span(); 691 extraction_span = 692 ExpandTokenSpan(symmetry_context_span, 693 /*num_tokens_left=*/max_selection_span + 694 bounds_sensitive_features->num_tokens_before(), 695 /*num_tokens_right=*/max_selection_span + 696 bounds_sensitive_features->num_tokens_after()); 697 } else { 698 // The extraction span is the symmetry context span expanded to include 699 // context_size tokens on either side. 
700 const int context_size = 701 selection_feature_processor_->GetOptions()->context_size(); 702 extraction_span = ExpandTokenSpan(symmetry_context_span, 703 /*num_tokens_left=*/context_size, 704 /*num_tokens_right=*/context_size); 705 } 706 extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()}); 707 708 if (!selection_feature_processor_->HasEnoughSupportedCodepoints( 709 *tokens, extraction_span)) { 710 return true; 711 } 712 713 std::unique_ptr<CachedFeatures> cached_features; 714 if (!selection_feature_processor_->ExtractFeatures( 715 *tokens, extraction_span, 716 /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex}, 717 embedding_executor_.get(), 718 /*embedding_cache=*/nullptr, 719 selection_feature_processor_->EmbeddingSize() + 720 selection_feature_processor_->DenseFeaturesCount(), 721 &cached_features)) { 722 TC_LOG(ERROR) << "Could not extract features."; 723 return false; 724 } 725 726 // Produce selection model candidates. 727 std::vector<TokenSpan> chunks; 728 if (!ModelChunk(tokens->size(), /*span_of_interest=*/symmetry_context_span, 729 interpreter_manager->SelectionInterpreter(), *cached_features, 730 &chunks)) { 731 TC_LOG(ERROR) << "Could not chunk."; 732 return false; 733 } 734 735 for (const TokenSpan& chunk : chunks) { 736 AnnotatedSpan candidate; 737 candidate.span = selection_feature_processor_->StripBoundaryCodepoints( 738 context_unicode, TokenSpanToCodepointSpan(*tokens, chunk)); 739 if (model_->selection_options()->strip_unpaired_brackets()) { 740 candidate.span = 741 StripUnpairedBrackets(context_unicode, candidate.span, *unilib_); 742 } 743 744 // Only output non-empty spans. 
745 if (candidate.span.first != candidate.span.second) { 746 result->push_back(candidate); 747 } 748 } 749 return true; 750 } 751 752 bool TextClassifier::ModelClassifyText( 753 const std::string& context, CodepointSpan selection_indices, 754 InterpreterManager* interpreter_manager, 755 FeatureProcessor::EmbeddingCache* embedding_cache, 756 std::vector<ClassificationResult>* classification_results) const { 757 if (model_->triggering_options() == nullptr || 758 !(model_->triggering_options()->enabled_modes() & 759 ModeFlag_CLASSIFICATION)) { 760 return true; 761 } 762 return ModelClassifyText(context, {}, selection_indices, interpreter_manager, 763 embedding_cache, classification_results); 764 } 765 766 namespace internal { 767 std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens, 768 CodepointSpan selection_indices, 769 TokenSpan tokens_around_selection_to_copy) { 770 const auto first_selection_token = std::upper_bound( 771 cached_tokens.begin(), cached_tokens.end(), selection_indices.first, 772 [](int selection_start, const Token& token) { 773 return selection_start < token.end; 774 }); 775 const auto last_selection_token = std::lower_bound( 776 cached_tokens.begin(), cached_tokens.end(), selection_indices.second, 777 [](const Token& token, int selection_end) { 778 return token.start < selection_end; 779 }); 780 781 const int64 first_token = std::max( 782 static_cast<int64>(0), 783 static_cast<int64>((first_selection_token - cached_tokens.begin()) - 784 tokens_around_selection_to_copy.first)); 785 const int64 last_token = std::min( 786 static_cast<int64>(cached_tokens.size()), 787 static_cast<int64>((last_selection_token - cached_tokens.begin()) + 788 tokens_around_selection_to_copy.second)); 789 790 std::vector<Token> tokens; 791 tokens.reserve(last_token - first_token); 792 for (int i = first_token; i < last_token; ++i) { 793 tokens.push_back(cached_tokens[i]); 794 } 795 return tokens; 796 } 797 } // namespace internal 798 799 TokenSpan 
TextClassifier::ClassifyTextUpperBoundNeededTokens() const {
  // Returns {num_tokens_left, num_tokens_right}: an upper bound on how many
  // tokens of context on each side of the selection the classification model
  // can consume, derived from the model options.
  const FeatureProcessorOptions_::BoundsSensitiveFeatures*
      bounds_sensitive_features =
          classification_feature_processor_->GetOptions()
              ->bounds_sensitive_features();
  if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
    // The extraction span is the selection span expanded to include a relevant
    // number of tokens outside of the bounds of the selection.
    return {bounds_sensitive_features->num_tokens_before(),
            bounds_sensitive_features->num_tokens_after()};
  } else {
    // The extraction span is the clicked token with context_size tokens on
    // either side.
    const int context_size =
        selection_feature_processor_->GetOptions()->context_size();
    return {context_size, context_size};
  }
}

// Classifies the span |selection_indices| in |context| with the TFLite
// classification model. |cached_tokens| may carry tokens from a previous
// tokenization of |context| to avoid re-tokenizing; pass an empty vector to
// tokenize from scratch. On success, |classification_results| holds one
// result per collection, sorted by descending score. Returns false on error.
bool TextClassifier::ModelClassifyText(
    const std::string& context, const std::vector<Token>& cached_tokens,
    CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
    FeatureProcessor::EmbeddingCache* embedding_cache,
    std::vector<ClassificationResult>* classification_results) const {
  std::vector<Token> tokens;
  if (cached_tokens.empty()) {
    tokens = classification_feature_processor_->Tokenize(context);
  } else {
    tokens = internal::CopyCachedTokens(cached_tokens, selection_indices,
                                        ClassifyTextUpperBoundNeededTokens());
  }

  int click_pos;
  classification_feature_processor_->RetokenizeAndFindClick(
      context, selection_indices,
      classification_feature_processor_->GetOptions()
          ->only_use_line_with_click(),
      &tokens, &click_pos);
  const TokenSpan selection_token_span =
      CodepointSpanToTokenSpan(tokens, selection_indices);
  const int selection_num_tokens = TokenSpanSize(selection_token_span);
  // Selections longer than the configured maximum are classified as "other"
  // without running the model.
  if (model_->classification_options()->max_num_tokens() > 0 &&
      model_->classification_options()->max_num_tokens() <
          selection_num_tokens) {
    *classification_results = {{kOtherCollection, 1.0}};
    return true;
  }

  const FeatureProcessorOptions_::BoundsSensitiveFeatures*
      bounds_sensitive_features =
          classification_feature_processor_->GetOptions()
              ->bounds_sensitive_features();
  if (selection_token_span.first == kInvalidIndex ||
      selection_token_span.second == kInvalidIndex) {
    TC_LOG(ERROR) << "Could not determine span.";
    return false;
  }

  // Compute the extraction span based on the model type.
  TokenSpan extraction_span;
  if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
    // The extraction span is the selection span expanded to include a relevant
    // number of tokens outside of the bounds of the selection.
    extraction_span = ExpandTokenSpan(
        selection_token_span,
        /*num_tokens_left=*/bounds_sensitive_features->num_tokens_before(),
        /*num_tokens_right=*/bounds_sensitive_features->num_tokens_after());
  } else {
    if (click_pos == kInvalidIndex) {
      TC_LOG(ERROR) << "Couldn't choose a click position.";
      return false;
    }
    // The extraction span is the clicked token with context_size tokens on
    // either side.
    const int context_size =
        classification_feature_processor_->GetOptions()->context_size();
    extraction_span = ExpandTokenSpan(SingleTokenSpan(click_pos),
                                      /*num_tokens_left=*/context_size,
                                      /*num_tokens_right=*/context_size);
  }
  // Clamp the extraction span to the tokenized text.
  extraction_span = IntersectTokenSpans(extraction_span, {0, tokens.size()});

  // Fall back to "other" if the span contains too many unsupported codepoints.
  if (!classification_feature_processor_->HasEnoughSupportedCodepoints(
          tokens, extraction_span)) {
    *classification_results = {{kOtherCollection, 1.0}};
    return true;
  }

  std::unique_ptr<CachedFeatures> cached_features;
  if (!classification_feature_processor_->ExtractFeatures(
          tokens, extraction_span, selection_indices, embedding_executor_.get(),
          embedding_cache,
          classification_feature_processor_->EmbeddingSize() +
              classification_feature_processor_->DenseFeaturesCount(),
          &cached_features)) {
    TC_LOG(ERROR) << "Could not extract features.";
    return false;
  }

  std::vector<float> features;
  features.reserve(cached_features->OutputFeaturesSize());
  if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
    cached_features->AppendBoundsSensitiveFeaturesForSpan(selection_token_span,
                                                          &features);
  } else {
    cached_features->AppendClickContextFeaturesForClick(click_pos, &features);
  }

  // Run the model on a batch of size one.
  TensorView<float> logits = classification_executor_->ComputeLogits(
      TensorView<float>(features.data(),
                        {1, static_cast<int>(features.size())}),
      interpreter_manager->ClassificationInterpreter());
  if (!logits.is_valid()) {
    TC_LOG(ERROR) << "Couldn't compute logits.";
    return false;
  }

  if (logits.dims() != 2 || logits.dim(0) != 1 ||
      logits.dim(1) != classification_feature_processor_->NumCollections()) {
    TC_LOG(ERROR) << "Mismatching output";
    return false;
  }

  const std::vector<float> scores =
      ComputeSoftmax(logits.data(), logits.dim(1));

  // Map each label index to its collection name and sort by descending score
  // so that the top classification comes first.
  classification_results->resize(scores.size());
  for (int i = 0; i < scores.size(); i++) {
    (*classification_results)[i] = {
        classification_feature_processor_->LabelToCollection(i), scores[i]};
  }
  std::sort(classification_results->begin(), classification_results->end(),
            [](const ClassificationResult& a, const ClassificationResult& b) {
              return a.score > b.score;
            });

  // Phone class sanity check.
  if (!classification_results->empty() &&
      classification_results->begin()->collection == kPhoneCollection) {
    const int digit_count = CountDigits(context, selection_indices);
    if (digit_count <
            model_->classification_options()->phone_min_num_digits() ||
        digit_count >
            model_->classification_options()->phone_max_num_digits()) {
      *classification_results = {{kOtherCollection, 1.0}};
    }
  }

  // Address class sanity check.
  if (!classification_results->empty() &&
      classification_results->begin()->collection == kAddressCollection) {
    if (selection_num_tokens <
        model_->classification_options()->address_min_num_tokens()) {
      *classification_results = {{kOtherCollection, 1.0}};
    }
  }

  return true;
}

// Classifies |selection_indices| in |context| with the regular expression
// models. Returns true and fills |classification_result| iff some pattern
// matches the whole selection.
bool TextClassifier::RegexClassifyText(
    const std::string& context, CodepointSpan selection_indices,
    ClassificationResult* classification_result) const {
  const std::string selection_text =
      ExtractSelection(context, selection_indices);
  const UnicodeText selection_text_unicode(
      UTF8ToUnicodeText(selection_text, /*do_copy=*/false));

  // Check whether any of the regular expressions match.
968 for (const int pattern_id : classification_regex_patterns_) { 969 const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id]; 970 const std::unique_ptr<UniLib::RegexMatcher> matcher = 971 regex_pattern.pattern->Matcher(selection_text_unicode); 972 int status = UniLib::RegexMatcher::kNoError; 973 bool matches; 974 if (regex_approximate_match_pattern_ids_.find(pattern_id) != 975 regex_approximate_match_pattern_ids_.end()) { 976 matches = matcher->ApproximatelyMatches(&status); 977 } else { 978 matches = matcher->Matches(&status); 979 } 980 if (status != UniLib::RegexMatcher::kNoError) { 981 return false; 982 } 983 if (matches) { 984 *classification_result = {regex_pattern.collection_name, 985 regex_pattern.target_classification_score, 986 regex_pattern.priority_score}; 987 return true; 988 } 989 if (status != UniLib::RegexMatcher::kNoError) { 990 TC_LOG(ERROR) << "Cound't match regex: " << pattern_id; 991 } 992 } 993 994 return false; 995 } 996 997 bool TextClassifier::DatetimeClassifyText( 998 const std::string& context, CodepointSpan selection_indices, 999 const ClassificationOptions& options, 1000 ClassificationResult* classification_result) const { 1001 if (!datetime_parser_) { 1002 return false; 1003 } 1004 1005 const std::string selection_text = 1006 ExtractSelection(context, selection_indices); 1007 1008 std::vector<DatetimeParseResultSpan> datetime_spans; 1009 if (!datetime_parser_->Parse(selection_text, options.reference_time_ms_utc, 1010 options.reference_timezone, options.locales, 1011 ModeFlag_CLASSIFICATION, 1012 /*anchor_start_end=*/true, &datetime_spans)) { 1013 TC_LOG(ERROR) << "Error during parsing datetime."; 1014 return false; 1015 } 1016 for (const DatetimeParseResultSpan& datetime_span : datetime_spans) { 1017 // Only consider the result valid if the selection and extracted datetime 1018 // spans exactly match. 
    // Shift the parser's selection-relative span back into context codepoints
    // and require an exact match with the requested selection.
    if (std::make_pair(datetime_span.span.first + selection_indices.first,
                       datetime_span.span.second + selection_indices.first) ==
        selection_indices) {
      *classification_result = {kDateCollection,
                                datetime_span.target_classification_score};
      classification_result->datetime_parse_result = datetime_span.data;
      return true;
    }
  }
  return false;
}

// Classifies the given selection span. Tries, in order: the regex models,
// the datetime model, and finally the TFLite classification model. Returns
// an empty vector if classification is disabled, the input is invalid, or no
// model produced a result.
std::vector<ClassificationResult> TextClassifier::ClassifyText(
    const std::string& context, CodepointSpan selection_indices,
    const ClassificationOptions& options) const {
  if (!initialized_) {
    TC_LOG(ERROR) << "Not initialized";
    return {};
  }

  if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
    return {};
  }

  // Bail out on invalid UTF-8 input.
  if (!UTF8ToUnicodeText(context, /*do_copy=*/false).is_valid()) {
    return {};
  }

  if (std::get<0>(selection_indices) >= std::get<1>(selection_indices)) {
    TC_VLOG(1) << "Trying to run ClassifyText with invalid indices: "
               << std::get<0>(selection_indices) << " "
               << std::get<1>(selection_indices);
    return {};
  }

  // Try the regular expression models.
  ClassificationResult regex_result;
  if (RegexClassifyText(context, selection_indices, &regex_result)) {
    if (!FilteredForClassification(regex_result)) {
      return {regex_result};
    } else {
      return {{kOtherCollection, 1.0}};
    }
  }

  // Try the date model.
  ClassificationResult datetime_result;
  if (DatetimeClassifyText(context, selection_indices, options,
                           &datetime_result)) {
    if (!FilteredForClassification(datetime_result)) {
      return {datetime_result};
    } else {
      return {{kOtherCollection, 1.0}};
    }
  }

  // Fallback to the model.
  std::vector<ClassificationResult> model_result;

  InterpreterManager interpreter_manager(selection_executor_.get(),
                                         classification_executor_.get());
  if (ModelClassifyText(context, selection_indices, &interpreter_manager,
                        /*embedding_cache=*/nullptr, &model_result) &&
      !model_result.empty()) {
    if (!FilteredForClassification(model_result[0])) {
      return model_result;
    } else {
      return {{kOtherCollection, 1.0}};
    }
  }

  // No classifications.
  return {};
}

// Runs the selection model over |context| (line by line when configured) and
// appends the discovered spans, with their classifications, to |result|.
// |tokens| is an out-parameter left holding the tokens of the last processed
// line. Returns false on error.
bool TextClassifier::ModelAnnotate(const std::string& context,
                                   InterpreterManager* interpreter_manager,
                                   std::vector<Token>* tokens,
                                   std::vector<AnnotatedSpan>* result) const {
  if (model_->triggering_options() == nullptr ||
      !(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
    return true;
  }

  const UnicodeText context_unicode = UTF8ToUnicodeText(context,
                                                        /*do_copy=*/false);
  // Either treat the whole context as one "line", or split it and process
  // each line independently.
  std::vector<UnicodeTextRange> lines;
  if (!selection_feature_processor_->GetOptions()->only_use_line_with_click()) {
    lines.push_back({context_unicode.begin(), context_unicode.end()});
  } else {
    lines = selection_feature_processor_->SplitContext(context_unicode);
  }

  // NOTE(review): triggering_options() was already null-checked at function
  // entry, so the != nullptr guard here is redundant (kept as-is).
  const float min_annotate_confidence =
      (model_->triggering_options() != nullptr
           ? model_->triggering_options()->min_annotate_confidence()
           : 0.f);

  FeatureProcessor::EmbeddingCache embedding_cache;
  for (const UnicodeTextRange& line : lines) {
    const std::string line_str =
        UnicodeText::UTF8Substring(line.first, line.second);

    *tokens = selection_feature_processor_->Tokenize(line_str);
    selection_feature_processor_->RetokenizeAndFindClick(
        line_str, {0, std::distance(line.first, line.second)},
        selection_feature_processor_->GetOptions()->only_use_line_with_click(),
        tokens,
        /*click_pos=*/nullptr);
    const TokenSpan full_line_span = {0, tokens->size()};

    // TODO(zilka): Add support for greater granularity of this check.
    if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
            *tokens, full_line_span)) {
      continue;
    }

    std::unique_ptr<CachedFeatures> cached_features;
    if (!selection_feature_processor_->ExtractFeatures(
            *tokens, full_line_span,
            /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
            embedding_executor_.get(),
            /*embedding_cache=*/nullptr,
            selection_feature_processor_->EmbeddingSize() +
                selection_feature_processor_->DenseFeaturesCount(),
            &cached_features)) {
      TC_LOG(ERROR) << "Could not extract features.";
      return false;
    }

    std::vector<TokenSpan> local_chunks;
    if (!ModelChunk(tokens->size(), /*span_of_interest=*/full_line_span,
                    interpreter_manager->SelectionInterpreter(),
                    *cached_features, &local_chunks)) {
      TC_LOG(ERROR) << "Could not chunk.";
      return false;
    }

    // Line-local codepoint spans are shifted by |offset| back into
    // context-global codepoints.
    const int offset = std::distance(context_unicode.begin(), line.first);
    for (const TokenSpan& chunk : local_chunks) {
      const CodepointSpan codepoint_span =
          selection_feature_processor_->StripBoundaryCodepoints(
              line_str, TokenSpanToCodepointSpan(*tokens, chunk));

      // Skip empty spans.
      if (codepoint_span.first != codepoint_span.second) {
        std::vector<ClassificationResult> classification;
        if (!ModelClassifyText(line_str, *tokens, codepoint_span,
                               interpreter_manager, &embedding_cache,
                               &classification)) {
          TC_LOG(ERROR) << "Could not classify text: "
                        << (codepoint_span.first + offset) << " "
                        << (codepoint_span.second + offset);
          return false;
        }

        // Do not include the span if it's classified as "other".
        if (!classification.empty() && !ClassifiedAsOther(classification) &&
            classification[0].score >= min_annotate_confidence) {
          AnnotatedSpan result_span;
          result_span.span = {codepoint_span.first + offset,
                              codepoint_span.second + offset};
          result_span.classification = std::move(classification);
          result->push_back(std::move(result_span));
        }
      }
    }
  }
  return true;
}

// Test-only accessor for the selection feature processor.
const FeatureProcessor* TextClassifier::SelectionFeatureProcessorForTests()
    const {
  return selection_feature_processor_.get();
}

// Test-only accessor for the classification feature processor.
const FeatureProcessor* TextClassifier::ClassificationFeatureProcessorForTests()
    const {
  return classification_feature_processor_.get();
}

// Test-only accessor for the datetime parser.
const DatetimeParser* TextClassifier::DatetimeParserForTests() const {
  return datetime_parser_.get();
}

// Annotates |context| with all enabled models (selection, regex, datetime),
// resolves conflicts between overlapping candidates, and returns the
// surviving spans sorted by position.
std::vector<AnnotatedSpan> TextClassifier::Annotate(
    const std::string& context, const AnnotationOptions& options) const {
  std::vector<AnnotatedSpan> candidates;

  if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) {
    return {};
  }

  if (!UTF8ToUnicodeText(context, /*do_copy=*/false).is_valid()) {
    return {};
  }

  InterpreterManager interpreter_manager(selection_executor_.get(),
                                         classification_executor_.get());
  // Annotate with the selection model.
1219 std::vector<Token> tokens; 1220 if (!ModelAnnotate(context, &interpreter_manager, &tokens, &candidates)) { 1221 TC_LOG(ERROR) << "Couldn't run ModelAnnotate."; 1222 return {}; 1223 } 1224 1225 // Annotate with the regular expression models. 1226 if (!RegexChunk(UTF8ToUnicodeText(context, /*do_copy=*/false), 1227 annotation_regex_patterns_, &candidates)) { 1228 TC_LOG(ERROR) << "Couldn't run RegexChunk."; 1229 return {}; 1230 } 1231 1232 // Annotate with the datetime model. 1233 if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false), 1234 options.reference_time_ms_utc, options.reference_timezone, 1235 options.locales, ModeFlag_ANNOTATION, &candidates)) { 1236 TC_LOG(ERROR) << "Couldn't run RegexChunk."; 1237 return {}; 1238 } 1239 1240 // Sort candidates according to their position in the input, so that the next 1241 // code can assume that any connected component of overlapping spans forms a 1242 // contiguous block. 1243 std::sort(candidates.begin(), candidates.end(), 1244 [](const AnnotatedSpan& a, const AnnotatedSpan& b) { 1245 return a.span.first < b.span.first; 1246 }); 1247 1248 std::vector<int> candidate_indices; 1249 if (!ResolveConflicts(candidates, context, tokens, &interpreter_manager, 1250 &candidate_indices)) { 1251 TC_LOG(ERROR) << "Couldn't resolve conflicts."; 1252 return {}; 1253 } 1254 1255 std::vector<AnnotatedSpan> result; 1256 result.reserve(candidate_indices.size()); 1257 for (const int i : candidate_indices) { 1258 if (!candidates[i].classification.empty() && 1259 !ClassifiedAsOther(candidates[i].classification) && 1260 !FilteredForAnnotation(candidates[i])) { 1261 result.push_back(std::move(candidates[i])); 1262 } 1263 } 1264 1265 return result; 1266 } 1267 1268 bool TextClassifier::RegexChunk(const UnicodeText& context_unicode, 1269 const std::vector<int>& rules, 1270 std::vector<AnnotatedSpan>* result) const { 1271 for (int pattern_id : rules) { 1272 const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id]; 
    const auto matcher = regex_pattern.pattern->Matcher(context_unicode);
    if (!matcher) {
      TC_LOG(ERROR) << "Could not get regex matcher for pattern: "
                    << pattern_id;
      return false;
    }

    int status = UniLib::RegexMatcher::kNoError;
    while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
      result->emplace_back();
      // Selection/annotation regular expressions need to specify a capturing
      // group specifying the selection.
      result->back().span = {matcher->Start(1, &status),
                             matcher->End(1, &status)};
      result->back().classification = {
          {regex_pattern.collection_name,
           regex_pattern.target_classification_score,
           regex_pattern.priority_score}};
    }
  }
  return true;
}

// Splits the token range |span_of_interest| into non-overlapping chunks
// (token spans) predicted by the selection model; results are written to
// |chunks| in left-to-right order. Returns false on error.
bool TextClassifier::ModelChunk(int num_tokens,
                                const TokenSpan& span_of_interest,
                                tflite::Interpreter* selection_interpreter,
                                const CachedFeatures& cached_features,
                                std::vector<TokenSpan>* chunks) const {
  const int max_selection_span =
      selection_feature_processor_->GetOptions()->max_selection_span();
  // The inference span is the span of interest expanded to include
  // max_selection_span tokens on either side, which is how far a selection can
  // stretch from the click.
  const TokenSpan inference_span = IntersectTokenSpans(
      ExpandTokenSpan(span_of_interest,
                      /*num_tokens_left=*/max_selection_span,
                      /*num_tokens_right=*/max_selection_span),
      {0, num_tokens});

  std::vector<ScoredChunk> scored_chunks;
  if (selection_feature_processor_->GetOptions()->bounds_sensitive_features() &&
      selection_feature_processor_->GetOptions()
          ->bounds_sensitive_features()
          ->enabled()) {
    if (!ModelBoundsSensitiveScoreChunks(
            num_tokens, span_of_interest, inference_span, cached_features,
            selection_interpreter, &scored_chunks)) {
      return false;
    }
  } else {
    if (!ModelClickContextScoreChunks(num_tokens, span_of_interest,
                                      cached_features, selection_interpreter,
                                      &scored_chunks)) {
      return false;
    }
  }
  // Sorting over reverse iterators with operator< leaves the underlying
  // vector ordered by *descending* score.
  std::sort(scored_chunks.rbegin(), scored_chunks.rend(),
            [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
              return lhs.score < rhs.score;
            });

  // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
  // them greedily as long as they do not overlap with any previously picked
  // chunks.
  std::vector<bool> token_used(TokenSpanSize(inference_span));
  chunks->clear();
  for (const ScoredChunk& scored_chunk : scored_chunks) {
    bool feasible = true;
    for (int i = scored_chunk.token_span.first;
         i < scored_chunk.token_span.second; ++i) {
      if (token_used[i - inference_span.first]) {
        feasible = false;
        break;
      }
    }

    if (!feasible) {
      continue;
    }

    // Mark the chunk's tokens as consumed.
    for (int i = scored_chunk.token_span.first;
         i < scored_chunk.token_span.second; ++i) {
      token_used[i - inference_span.first] = true;
    }

    chunks->push_back(scored_chunk.token_span);
  }

  // Return the picked chunks in left-to-right order.
  std::sort(chunks->begin(), chunks->end());

  return true;
}

namespace {
// Updates the value at the given key in the map to maximum of the current value
// and the given value, or simply inserts the value if the key is not yet there.
template <typename Map>
void UpdateMax(Map* map, typename Map::key_type key,
               typename Map::mapped_type value) {
  const auto it = map->find(key);
  if (it != map->end()) {
    it->second = std::max(it->second, value);
  } else {
    (*map)[key] = value;
  }
}
}  // namespace

// Scores candidate chunks around each possible click position using the
// click-context selection model, batching clicks for inference. The best
// score seen for each candidate span is reported in |scored_chunks|.
bool TextClassifier::ModelClickContextScoreChunks(
    int num_tokens, const TokenSpan& span_of_interest,
    const CachedFeatures& cached_features,
    tflite::Interpreter* selection_interpreter,
    std::vector<ScoredChunk>* scored_chunks) const {
  const int max_batch_size = model_->selection_options()->batch_size();

  std::vector<float> all_features;
  std::map<TokenSpan, float> chunk_scores;
  for (int batch_start = span_of_interest.first;
       batch_start < span_of_interest.second; batch_start += max_batch_size) {
    const int batch_end =
        std::min(batch_start + max_batch_size, span_of_interest.second);

    // Prepare features for the whole batch.
    all_features.clear();
    all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
    for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
      cached_features.AppendClickContextFeaturesForClick(click_pos,
                                                         &all_features);
    }

    // Run batched inference.
    const int batch_size = batch_end - batch_start;
    const int features_size = cached_features.OutputFeaturesSize();
    TensorView<float> logits = selection_executor_->ComputeLogits(
        TensorView<float>(all_features.data(), {batch_size, features_size}),
        selection_interpreter);
    if (!logits.is_valid()) {
      TC_LOG(ERROR) << "Couldn't compute logits.";
      return false;
    }
    if (logits.dims() != 2 || logits.dim(0) != batch_size ||
        logits.dim(1) !=
            selection_feature_processor_->GetSelectionLabelCount()) {
      TC_LOG(ERROR) << "Mismatching output.";
      return false;
    }

    // Save results.
    for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
      const std::vector<float> scores = ComputeSoftmax(
          logits.data() + logits.dim(1) * (click_pos - batch_start),
          logits.dim(1));
      // Each selection label encodes a span relative to the click; expand it
      // into an absolute candidate span and keep the best score per span.
      for (int j = 0;
           j < selection_feature_processor_->GetSelectionLabelCount(); ++j) {
        TokenSpan relative_token_span;
        if (!selection_feature_processor_->LabelToTokenSpan(
                j, &relative_token_span)) {
          TC_LOG(ERROR) << "Couldn't map the label to a token span.";
          return false;
        }
        const TokenSpan candidate_span = ExpandTokenSpan(
            SingleTokenSpan(click_pos), relative_token_span.first,
            relative_token_span.second);
        if (candidate_span.first >= 0 && candidate_span.second <= num_tokens) {
          UpdateMax(&chunk_scores, candidate_span, scores[j]);
        }
      }
    }
  }

  scored_chunks->clear();
  scored_chunks->reserve(chunk_scores.size());
  for (const auto& entry : chunk_scores) {
    scored_chunks->push_back(ScoredChunk{entry.first, entry.second});
  }

  return true;
}

// Scores candidate chunks with the bounds-sensitive selection model, batching
// candidate spans for inference. Single-token spans may optionally be scored
// as zero without running the model.
bool TextClassifier::ModelBoundsSensitiveScoreChunks(
    int num_tokens, const TokenSpan& span_of_interest,
    const TokenSpan& inference_span, const CachedFeatures& cached_features,
    tflite::Interpreter* selection_interpreter,
    std::vector<ScoredChunk>* scored_chunks) const {
  const int max_selection_span =
      selection_feature_processor_->GetOptions()->max_selection_span();
  const int max_chunk_length = selection_feature_processor_->GetOptions()
                                       ->selection_reduced_output_space()
                                   ? max_selection_span + 1
                                   : 2 * max_selection_span + 1;
  const bool score_single_token_spans_as_zero =
      selection_feature_processor_->GetOptions()
          ->bounds_sensitive_features()
          ->score_single_token_spans_as_zero();

  scored_chunks->clear();
  if (score_single_token_spans_as_zero) {
    scored_chunks->reserve(TokenSpanSize(span_of_interest));
  }

  // Prepare all chunk candidates into one batch:
  //   - Are contained in the inference span
  //   - Have a non-empty intersection with the span of interest
  //   - Are at least one token long
  //   - Are not longer than the maximum chunk length
  std::vector<TokenSpan> candidate_spans;
  for (int start = inference_span.first; start < span_of_interest.second;
       ++start) {
    const int leftmost_end_index = std::max(start, span_of_interest.first) + 1;
    for (int end = leftmost_end_index;
         end <= inference_span.second && end - start <= max_chunk_length;
         ++end) {
      const TokenSpan candidate_span = {start, end};
      if (score_single_token_spans_as_zero &&
          TokenSpanSize(candidate_span) == 1) {
        // Do not include the single token span in the batch, add a zero score
        // for it directly to the output.
        scored_chunks->push_back(ScoredChunk{candidate_span, 0.0f});
      } else {
        candidate_spans.push_back(candidate_span);
      }
    }
  }

  const int max_batch_size = model_->selection_options()->batch_size();

  std::vector<float> all_features;
  scored_chunks->reserve(scored_chunks->size() + candidate_spans.size());
  for (int batch_start = 0; batch_start < candidate_spans.size();
       batch_start += max_batch_size) {
    const int batch_end = std::min(batch_start + max_batch_size,
                                   static_cast<int>(candidate_spans.size()));

    // Prepare features for the whole batch.
    all_features.clear();
    all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
    for (int i = batch_start; i < batch_end; ++i) {
      cached_features.AppendBoundsSensitiveFeaturesForSpan(candidate_spans[i],
                                                           &all_features);
    }

    // Run batched inference.
    const int batch_size = batch_end - batch_start;
    const int features_size = cached_features.OutputFeaturesSize();
    TensorView<float> logits = selection_executor_->ComputeLogits(
        TensorView<float>(all_features.data(), {batch_size, features_size}),
        selection_interpreter);
    if (!logits.is_valid()) {
      TC_LOG(ERROR) << "Couldn't compute logits.";
      return false;
    }
    if (logits.dims() != 2 || logits.dim(0) != batch_size ||
        logits.dim(1) != 1) {
      TC_LOG(ERROR) << "Mismatching output.";
      return false;
    }

    // Save results.
    for (int i = batch_start; i < batch_end; ++i) {
      scored_chunks->push_back(
          ScoredChunk{candidate_spans[i], logits.data()[i - batch_start]});
    }
  }

  return true;
}

// Parses |context_unicode| for datetime occurrences and appends one
// AnnotatedSpan (collection "date") per parse result.
bool TextClassifier::DatetimeChunk(const UnicodeText& context_unicode,
                                   int64 reference_time_ms_utc,
                                   const std::string& reference_timezone,
                                   const std::string& locales, ModeFlag mode,
                                   std::vector<AnnotatedSpan>* result) const {
  // No parser configured means there is simply nothing to add — not an error.
  if (!datetime_parser_) {
    return true;
  }

  std::vector<DatetimeParseResultSpan> datetime_spans;
  if (!datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
                               reference_timezone, locales, mode,
                               /*anchor_start_end=*/false, &datetime_spans)) {
    return false;
  }
  for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
    AnnotatedSpan annotated_span;
    annotated_span.span = datetime_span.span;
    annotated_span.classification = {{kDateCollection,
                                      datetime_span.target_classification_score,
                                      datetime_span.priority_score}};
    annotated_span.classification[0].datetime_parse_result = datetime_span.data;

    result->push_back(std::move(annotated_span));
  }
  return true;
}

// Returns a verified flatbuffer Model view over |buffer|, or nullptr if the
// buffer is null or fails verification.
const Model* ViewModel(const void* buffer, int size) {
  if (!buffer) {
    return nullptr;
  }

  return LoadAndVerifyModel(buffer, size);
}

}  // namespace libtextclassifier2