/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "annotator/annotator.h"

#include <algorithm>
#include <cctype>
#include <cmath>
#include <iterator>
#include <numeric>
#include <unordered_map>

#include "annotator/collections.h"
#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "utils/base/logging.h"
#include "utils/checksum.h"
#include "utils/math/softmax.h"
#include "utils/regex-match.h"
#include "utils/utf8/unicodetext.h"
#include "utils/zlib/zlib_regex.h"

namespace libtextclassifier3 {

// Int set with a caller-supplied ordering; used during conflict resolution to
// keep chosen candidate indices sorted by the span start of the candidate
// they refer to.
using SortedIntSet = std::set<int, std::function<bool(int, int)>>;

// Well-known collection names. The backing strings are heap-allocated once
// via an immediately-invoked lambda and intentionally never freed, which
// sidesteps static-destruction-order problems at process shutdown.
const std::string& Annotator::kPhoneCollection =
    *[]() { return new std::string("phone"); }();
const std::string& Annotator::kAddressCollection =
    *[]() { return new std::string("address"); }();
const std::string& Annotator::kDateCollection =
    *[]() { return new std::string("date"); }();
const std::string& Annotator::kUrlCollection =
    *[]() { return new std::string("url"); }();
const std::string& Annotator::kEmailCollection =
    *[]() { return new std::string("email"); }();

namespace {

// Verifies that [addr, addr + size) holds a valid flatbuffer-serialized Model
// and returns the typed pointer into the buffer, or nullptr if verification
// fails. The returned pointer aliases 'addr'; the caller keeps ownership of
// the underlying memory.
const Model* LoadAndVerifyModel(const void* addr, int size) {
  flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
  if (VerifyModelBuffer(verifier)) {
    return GetModel(addr);
  } else {
    return nullptr;
  }
}

// If lib is not nullptr, just returns lib. Otherwise, if lib is nullptr, will
// create a new instance, assign ownership to owned_lib, and return it.
const UniLib* MaybeCreateUnilib(const UniLib* lib,
                                std::unique_ptr<UniLib>* owned_lib) {
  if (lib) {
    return lib;
  } else {
    owned_lib->reset(new UniLib);
    return owned_lib->get();
  }
}

// As above, but for CalendarLib.
const CalendarLib* MaybeCreateCalendarlib(
    const CalendarLib* lib, std::unique_ptr<CalendarLib>* owned_lib) {
  if (lib) {
    return lib;
  } else {
    owned_lib->reset(new CalendarLib);
    return owned_lib->get();
  }
}

}  // namespace

// Lazily builds and caches the TFLite interpreter for the selection model.
// Returns nullptr if the interpreter could not be built; subsequent calls
// return the cached instance.
tflite::Interpreter* InterpreterManager::SelectionInterpreter() {
  if (!selection_interpreter_) {
    TC3_CHECK(selection_executor_);
    selection_interpreter_ = selection_executor_->CreateInterpreter();
    if (!selection_interpreter_) {
      TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
    }
  }
  return selection_interpreter_.get();
}

// Lazily builds and caches the TFLite interpreter for the classification
// model. Same contract as SelectionInterpreter().
tflite::Interpreter* InterpreterManager::ClassificationInterpreter() {
  if (!classification_interpreter_) {
    TC3_CHECK(classification_executor_);
    classification_interpreter_ = classification_executor_->CreateInterpreter();
    if (!classification_interpreter_) {
      TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
    }
  }
  return classification_interpreter_.get();
}

// Creates an Annotator from a model buffer that the caller continues to own;
// the buffer must outlive the returned instance. Returns nullptr if the
// buffer does not verify as a Model or initialization fails.
std::unique_ptr<Annotator> Annotator::FromUnownedBuffer(
    const char* buffer, int size, const UniLib* unilib,
    const CalendarLib* calendarlib) {
  const Model* model = LoadAndVerifyModel(buffer, size);
  if (model == nullptr) {
    return nullptr;
  }

  auto classifier =
      std::unique_ptr<Annotator>(new Annotator(model, unilib, calendarlib));
  if (!classifier->IsInitialized()) {
    return nullptr;
  }

  return classifier;
}

// Creates an Annotator that takes ownership of the mmapped model region
// (*mmap is moved from). 'unilib'/'calendarlib' are borrowed and may be
// nullptr, in which case owned instances are created. Returns nullptr on
// mmap, verification or initialization failure.
std::unique_ptr<Annotator> Annotator::FromScopedMmap(
    std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib,
    const CalendarLib* calendarlib) {
  if (!(*mmap)->handle().ok()) {
    TC3_VLOG(1) << "Mmap failed.";
    return nullptr;
  }

  const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
                                          (*mmap)->handle().num_bytes());
  if (!model) {
    TC3_LOG(ERROR) << "Model verification failed.";
    return nullptr;
  }

  auto classifier = std::unique_ptr<Annotator>(
      new Annotator(mmap, model, unilib, calendarlib));
  if (!classifier->IsInitialized()) {
    return nullptr;
  }

  return classifier;
}

// As above, but takes ownership of the UniLib and CalendarLib instances.
std::unique_ptr<Annotator> Annotator::FromScopedMmap(
    std::unique_ptr<ScopedMmap>* mmap, std::unique_ptr<UniLib> unilib,
    std::unique_ptr<CalendarLib> calendarlib) {
  if (!(*mmap)->handle().ok()) {
    TC3_VLOG(1) << "Mmap failed.";
    return nullptr;
  }

  const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
                                          (*mmap)->handle().num_bytes());
  if (model == nullptr) {
    TC3_LOG(ERROR) << "Model verification failed.";
    return nullptr;
  }

  auto classifier = std::unique_ptr<Annotator>(
      new Annotator(mmap, model, std::move(unilib), std::move(calendarlib)));
  if (!classifier->IsInitialized()) {
    return nullptr;
  }

  return classifier;
}

// Convenience factory: mmaps [offset, offset + size) of the given file
// descriptor and delegates to FromScopedMmap. Borrows unilib/calendarlib.
std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
    int fd, int offset, int size, const UniLib* unilib,
    const CalendarLib* calendarlib) {
  std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
  return FromScopedMmap(&mmap, unilib, calendarlib);
}

// As above, but takes ownership of unilib/calendarlib.
std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
    int fd, int offset, int size, std::unique_ptr<UniLib> unilib,
    std::unique_ptr<CalendarLib> calendarlib) {
  std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
  return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
}

// Convenience factory: mmaps the whole file descriptor. Borrows
// unilib/calendarlib.
std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
    int fd, const UniLib* unilib, const CalendarLib* calendarlib) {
  std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
  return FromScopedMmap(&mmap, unilib, calendarlib);
}

// As above, but takes ownership of unilib/calendarlib.
std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
    int fd, std::unique_ptr<UniLib> unilib,
    std::unique_ptr<CalendarLib> calendarlib) {
  std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
  return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
}

// Convenience factory: mmaps the model file at 'path'. Borrows
// unilib/calendarlib.
std::unique_ptr<Annotator> Annotator::FromPath(const std::string& path,
                                               const UniLib* unilib,
                                               const CalendarLib* calendarlib) {
  std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
  return FromScopedMmap(&mmap, unilib, calendarlib);
}

// As above, but takes ownership of unilib/calendarlib.
std::unique_ptr<Annotator> Annotator::FromPath(
    const std::string& path, std::unique_ptr<UniLib> unilib,
    std::unique_ptr<CalendarLib> calendarlib) {
  std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
  return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
}

// Constructor taking ownership of the mmap; unilib/calendarlib are borrowed
// (owned instances are created when the borrowed pointers are null).
Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
                     const UniLib* unilib, const CalendarLib* calendarlib)
    : model_(model),
      mmap_(std::move(*mmap)),
      owned_unilib_(nullptr),
      unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)),
      owned_calendarlib_(nullptr),
      calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) {
  ValidateAndInitialize();
}

// Constructor taking ownership of the mmap and of unilib/calendarlib.
Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
                     std::unique_ptr<UniLib> unilib,
                     std::unique_ptr<CalendarLib> calendarlib)
    : model_(model),
      mmap_(std::move(*mmap)),
      owned_unilib_(std::move(unilib)),
      unilib_(owned_unilib_.get()),
      owned_calendarlib_(std::move(calendarlib)),
      calendarlib_(owned_calendarlib_.get()) {
  ValidateAndInitialize();
}

// Constructor for an unowned model buffer; unilib/calendarlib are borrowed
// (owned instances are created when the borrowed pointers are null).
Annotator::Annotator(const Model* model, const UniLib* unilib,
                     const CalendarLib* calendarlib)
    : model_(model),
      owned_unilib_(nullptr),
      unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)),
      owned_calendarlib_(nullptr),
      calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) {
  ValidateAndInitialize();
}

// Validates the flatbuffer model and builds every component the configured
// modes require: model executors, feature processors, the regex/datetime/
// number/duration annotators, output-collection filters and locale lists.
// Sets initialized_ to true only if every required piece could be
// constructed; on any failure it logs and returns early, leaving the
// instance unusable.
void Annotator::ValidateAndInitialize() {
  initialized_ = false;

  if (model_ == nullptr) {
    TC3_LOG(ERROR) << "No model specified.";
    return;
  }

  // Which high-level modes the model opts into (all default to off when no
  // triggering options are present).
  const bool model_enabled_for_annotation =
      (model_->triggering_options() != nullptr &&
       (model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION));
  const bool model_enabled_for_classification =
      (model_->triggering_options() != nullptr &&
       (model_->triggering_options()->enabled_modes() &
        ModeFlag_CLASSIFICATION));
  const bool model_enabled_for_selection =
      (model_->triggering_options() != nullptr &&
       (model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION));

  // Annotation requires the selection model.
  if (model_enabled_for_annotation || model_enabled_for_selection) {
    if (!model_->selection_options()) {
      TC3_LOG(ERROR) << "No selection options.";
      return;
    }
    if (!model_->selection_feature_options()) {
      TC3_LOG(ERROR) << "No selection feature options.";
      return;
    }
    if (!model_->selection_feature_options()->bounds_sensitive_features()) {
      TC3_LOG(ERROR) << "No selection bounds sensitive feature options.";
      return;
    }
    if (!model_->selection_model()) {
      TC3_LOG(ERROR) << "No selection model.";
      return;
    }
    selection_executor_ = ModelExecutor::FromBuffer(model_->selection_model());
    if (!selection_executor_) {
      TC3_LOG(ERROR) << "Could not initialize selection executor.";
      return;
    }
    selection_feature_processor_.reset(
        new FeatureProcessor(model_->selection_feature_options(), unilib_));
  }

  // Annotation requires the classification model for conflict resolution and
  // scoring.
  // Selection requires the classification model for conflict resolution.
  if (model_enabled_for_annotation || model_enabled_for_classification ||
      model_enabled_for_selection) {
    if (!model_->classification_options()) {
      TC3_LOG(ERROR) << "No classification options.";
      return;
    }

    if (!model_->classification_feature_options()) {
      TC3_LOG(ERROR) << "No classification feature options.";
      return;
    }

    if (!model_->classification_feature_options()
             ->bounds_sensitive_features()) {
      TC3_LOG(ERROR) << "No classification bounds sensitive feature options.";
      return;
    }
    if (!model_->classification_model()) {
      TC3_LOG(ERROR) << "No clf model.";
      return;
    }

    classification_executor_ =
        ModelExecutor::FromBuffer(model_->classification_model());
    if (!classification_executor_) {
      TC3_LOG(ERROR) << "Could not initialize classification executor.";
      return;
    }

    classification_feature_processor_.reset(new FeatureProcessor(
        model_->classification_feature_options(), unilib_));
  }

  // The embeddings need to be specified if the model is to be used for
  // classification or selection.
  if (model_enabled_for_annotation || model_enabled_for_classification ||
      model_enabled_for_selection) {
    if (!model_->embedding_model()) {
      TC3_LOG(ERROR) << "No embedding model.";
      return;
    }

    // Check that the embedding size of the selection and classification model
    // matches, as they are using the same embeddings.
    if (model_enabled_for_selection &&
        (model_->selection_feature_options()->embedding_size() !=
             model_->classification_feature_options()->embedding_size() ||
         model_->selection_feature_options()->embedding_quantization_bits() !=
             model_->classification_feature_options()
                 ->embedding_quantization_bits())) {
      TC3_LOG(ERROR) << "Mismatching embedding size/quantization.";
      return;
    }

    embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
        model_->embedding_model(),
        model_->classification_feature_options()->embedding_size(),
        model_->classification_feature_options()->embedding_quantization_bits(),
        model_->embedding_pruning_mask());
    if (!embedding_executor_) {
      TC3_LOG(ERROR) << "Could not initialize embedding executor.";
      return;
    }
  }

  // One decompressor instance is shared by the regex and datetime model
  // initialization below.
  std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
  if (model_->regex_model()) {
    if (!InitializeRegexModel(decompressor.get())) {
      TC3_LOG(ERROR) << "Could not initialize regex model.";
      return;
    }
  }

  if (model_->datetime_model()) {
    datetime_parser_ = DatetimeParser::Instance(
        model_->datetime_model(), *unilib_, *calendarlib_, decompressor.get());
    if (!datetime_parser_) {
      TC3_LOG(ERROR) << "Could not initialize datetime parser.";
      return;
    }
  }

  // Collection-name filters applied to the outputs of the three entry points.
  if (model_->output_options()) {
    if (model_->output_options()->filtered_collections_annotation()) {
      for (const auto collection :
           *model_->output_options()->filtered_collections_annotation()) {
        filtered_collections_annotation_.insert(collection->str());
      }
    }
    if (model_->output_options()->filtered_collections_classification()) {
      for (const auto collection :
           *model_->output_options()->filtered_collections_classification()) {
        filtered_collections_classification_.insert(collection->str());
      }
    }
    if (model_->output_options()->filtered_collections_selection()) {
      for (const auto collection :
           *model_->output_options()->filtered_collections_selection()) {
        filtered_collections_selection_.insert(collection->str());
      }
    }
  }

  if (model_->number_annotator_options() &&
      model_->number_annotator_options()->enabled()) {
    // The number annotator requires the selection feature processor built
    // above; a model that enables numbers but not selection is invalid.
    if (selection_feature_processor_ == nullptr) {
      TC3_LOG(ERROR)
          << "Could not initialize NumberAnnotator without a feature processor";
      return;
    }

    number_annotator_.reset(
        new NumberAnnotator(model_->number_annotator_options(),
                            selection_feature_processor_.get()));
  }

  if (model_->duration_annotator_options() &&
      model_->duration_annotator_options()->enabled()) {
    duration_annotator_.reset(
        new DurationAnnotator(model_->duration_annotator_options(),
                              selection_feature_processor_.get()));
  }

  // Optional schema used to build serialized entity data for matches.
  if (model_->entity_data_schema()) {
    entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
        model_->entity_data_schema()->Data(),
        model_->entity_data_schema()->size());
    if (entity_data_schema_ == nullptr) {
      TC3_LOG(ERROR) << "Could not load entity data schema data.";
      return;
    }

    entity_data_builder_.reset(
        new ReflectiveFlatbufferBuilder(entity_data_schema_));
  } else {
    entity_data_schema_ = nullptr;
    entity_data_builder_ = nullptr;
  }

  if (model_->triggering_locales() &&
      !ParseLocales(model_->triggering_locales()->c_str(),
                    &model_triggering_locales_)) {
    TC3_LOG(ERROR) << "Could not parse model supported locales.";
    return;
  }

  if (model_->triggering_options() != nullptr &&
      model_->triggering_options()->locales() != nullptr &&
      !ParseLocales(model_->triggering_options()->locales()->c_str(),
                    &ml_model_triggering_locales_)) {
    TC3_LOG(ERROR) << "Could not parse supported ML model locales.";
    return;
  }

  if (model_->triggering_options() != nullptr &&
      model_->triggering_options()->dictionary_locales() != nullptr &&
      !ParseLocales(model_->triggering_options()->dictionary_locales()->c_str(),
                    &dictionary_locales_)) {
    TC3_LOG(ERROR) << "Could not parse dictionary supported locales.";
    return;
  }

  initialized_ = true;
}

// Compiles the regex patterns declared by the model (decompressing them if
// needed) and records, per mode, which pattern ids participate in
// annotation, classification and selection. Returns false if any pattern
// fails to compile.
bool Annotator::InitializeRegexModel(ZlibDecompressor* decompressor) {
  if (!model_->regex_model()->patterns()) {
    return true;
  }

  // Initialize pattern recognizers.
  int regex_pattern_id = 0;
  for (const auto& regex_pattern : *model_->regex_model()->patterns()) {
    std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
        UncompressMakeRegexPattern(
            *unilib_, regex_pattern->pattern(),
            regex_pattern->compressed_pattern(),
            model_->regex_model()->lazy_regex_compilation(), decompressor);
    if (!compiled_pattern) {
      TC3_LOG(INFO) << "Failed to load regex pattern";
      return false;
    }

    if (regex_pattern->enabled_modes() & ModeFlag_ANNOTATION) {
      annotation_regex_patterns_.push_back(regex_pattern_id);
    }
    if (regex_pattern->enabled_modes() & ModeFlag_CLASSIFICATION) {
      classification_regex_patterns_.push_back(regex_pattern_id);
    }
    if (regex_pattern->enabled_modes() & ModeFlag_SELECTION) {
      selection_regex_patterns_.push_back(regex_pattern_id);
    }
    regex_patterns_.push_back({
        regex_pattern,
        std::move(compiled_pattern),
    });
    ++regex_pattern_id;
  }

  return true;
}

// Builds the optional knowledge engine from its serialized config. On
// failure the previous engine (if any) is left untouched.
bool Annotator::InitializeKnowledgeEngine(
    const std::string& serialized_config) {
  std::unique_ptr<KnowledgeEngine> knowledge_engine(
      new KnowledgeEngine(unilib_));
  if (!knowledge_engine->Initialize(serialized_config)) {
    TC3_LOG(ERROR) << "Failed to initialize the knowledge engine.";
    return false;
  }
  knowledge_engine_ = std::move(knowledge_engine);
  return true;
}

// Builds the optional contact engine from its serialized config. On failure
// the previous engine (if any) is left untouched.
bool Annotator::InitializeContactEngine(const std::string& serialized_config) {
  std::unique_ptr<ContactEngine> contact_engine(
      new
ContactEngine(selection_feature_processor_.get(), unilib_)); 515 if (!contact_engine->Initialize(serialized_config)) { 516 TC3_LOG(ERROR) << "Failed to initialize the contact engine."; 517 return false; 518 } 519 contact_engine_ = std::move(contact_engine); 520 return true; 521 } 522 523 bool Annotator::InitializeInstalledAppEngine( 524 const std::string& serialized_config) { 525 std::unique_ptr<InstalledAppEngine> installed_app_engine( 526 new InstalledAppEngine(selection_feature_processor_.get(), unilib_)); 527 if (!installed_app_engine->Initialize(serialized_config)) { 528 TC3_LOG(ERROR) << "Failed to initialize the installed app engine."; 529 return false; 530 } 531 installed_app_engine_ = std::move(installed_app_engine); 532 return true; 533 } 534 535 namespace { 536 537 int CountDigits(const std::string& str, CodepointSpan selection_indices) { 538 int count = 0; 539 int i = 0; 540 const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false); 541 for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) { 542 if (i >= selection_indices.first && i < selection_indices.second && 543 isdigit(*it)) { 544 ++count; 545 } 546 } 547 return count; 548 } 549 550 } // namespace 551 552 namespace internal { 553 // Helper function, which if the initial 'span' contains only white-spaces, 554 // moves the selection to a single-codepoint selection on a left or right side 555 // of this space. 556 CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span, 557 const UnicodeText& context_unicode, 558 const UniLib& unilib) { 559 TC3_CHECK(ValidNonEmptySpan(span)); 560 561 UnicodeText::const_iterator it; 562 563 // Check that the current selection is all whitespaces. 564 it = context_unicode.begin(); 565 std::advance(it, span.first); 566 for (int i = 0; i < (span.second - span.first); ++i, ++it) { 567 if (!unilib.IsWhitespace(*it)) { 568 return span; 569 } 570 } 571 572 CodepointSpan result; 573 574 // Try moving left. 
575 result = span; 576 it = context_unicode.begin(); 577 std::advance(it, span.first); 578 while (it != context_unicode.begin() && unilib.IsWhitespace(*it)) { 579 --result.first; 580 --it; 581 } 582 result.second = result.first + 1; 583 if (!unilib.IsWhitespace(*it)) { 584 return result; 585 } 586 587 // If moving left didn't find a non-whitespace character, just return the 588 // original span. 589 return span; 590 } 591 } // namespace internal 592 593 bool Annotator::FilteredForAnnotation(const AnnotatedSpan& span) const { 594 return !span.classification.empty() && 595 filtered_collections_annotation_.find( 596 span.classification[0].collection) != 597 filtered_collections_annotation_.end(); 598 } 599 600 bool Annotator::FilteredForClassification( 601 const ClassificationResult& classification) const { 602 return filtered_collections_classification_.find(classification.collection) != 603 filtered_collections_classification_.end(); 604 } 605 606 bool Annotator::FilteredForSelection(const AnnotatedSpan& span) const { 607 return !span.classification.empty() && 608 filtered_collections_selection_.find( 609 span.classification[0].collection) != 610 filtered_collections_selection_.end(); 611 } 612 613 namespace { 614 inline bool ClassifiedAsOther( 615 const std::vector<ClassificationResult>& classification) { 616 return !classification.empty() && 617 classification[0].collection == Collections::Other(); 618 } 619 620 float GetPriorityScore( 621 const std::vector<ClassificationResult>& classification) { 622 if (!classification.empty() && !ClassifiedAsOther(classification)) { 623 return classification[0].priority_score; 624 } else { 625 return -1.0; 626 } 627 } 628 } // namespace 629 630 bool Annotator::VerifyRegexMatchCandidate( 631 const std::string& context, const VerificationOptions* verification_options, 632 const std::string& match, const UniLib::RegexMatcher* matcher) const { 633 if (verification_options == nullptr) { 634 return true; 635 } 636 if 
(verification_options->verify_luhn_checksum() && 637 !VerifyLuhnChecksum(match)) { 638 return false; 639 } 640 const int lua_verifier = verification_options->lua_verifier(); 641 if (lua_verifier >= 0) { 642 if (model_->regex_model()->lua_verifier() == nullptr || 643 lua_verifier >= model_->regex_model()->lua_verifier()->size()) { 644 TC3_LOG(ERROR) << "Invalid lua verifier specified: " << lua_verifier; 645 return false; 646 } 647 return VerifyMatch( 648 context, matcher, 649 model_->regex_model()->lua_verifier()->Get(lua_verifier)->str()); 650 } 651 return true; 652 } 653 654 CodepointSpan Annotator::SuggestSelection( 655 const std::string& context, CodepointSpan click_indices, 656 const SelectionOptions& options) const { 657 CodepointSpan original_click_indices = click_indices; 658 if (!initialized_) { 659 TC3_LOG(ERROR) << "Not initialized"; 660 return original_click_indices; 661 } 662 if (!(model_->enabled_modes() & ModeFlag_SELECTION)) { 663 return original_click_indices; 664 } 665 666 std::vector<Locale> detected_text_language_tags; 667 if (!ParseLocales(options.detected_text_language_tags, 668 &detected_text_language_tags)) { 669 TC3_LOG(WARNING) 670 << "Failed to parse the detected_text_language_tags in options: " 671 << options.detected_text_language_tags; 672 } 673 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags, 674 model_triggering_locales_, 675 /*default_value=*/true)) { 676 return original_click_indices; 677 } 678 679 const UnicodeText context_unicode = UTF8ToUnicodeText(context, 680 /*do_copy=*/false); 681 682 if (!context_unicode.is_valid()) { 683 return original_click_indices; 684 } 685 686 const int context_codepoint_size = context_unicode.size_codepoints(); 687 688 if (click_indices.first < 0 || click_indices.second < 0 || 689 click_indices.first >= context_codepoint_size || 690 click_indices.second > context_codepoint_size || 691 click_indices.first >= click_indices.second) { 692 TC3_VLOG(1) << "Trying to run SuggestSelection with 
invalid indices: " 693 << click_indices.first << " " << click_indices.second; 694 return original_click_indices; 695 } 696 697 if (model_->snap_whitespace_selections()) { 698 // We want to expand a purely white-space selection to a multi-selection it 699 // would've been part of. But with this feature disabled we would do a no- 700 // op, because no token is found. Therefore, we need to modify the 701 // 'click_indices' a bit to include a part of the token, so that the click- 702 // finding logic finds the clicked token correctly. This modification is 703 // done by the following function. Note, that it's enough to check the left 704 // side of the current selection, because if the white-space is a part of a 705 // multi-selection, necessarily both tokens - on the left and the right 706 // sides need to be selected. Thus snapping only to the left is sufficient 707 // (there's a check at the bottom that makes sure that if we snap to the 708 // left token but the result does not contain the initial white-space, 709 // returns the original indices). 
710 click_indices = internal::SnapLeftIfWhitespaceSelection( 711 click_indices, context_unicode, *unilib_); 712 } 713 714 std::vector<AnnotatedSpan> candidates; 715 InterpreterManager interpreter_manager(selection_executor_.get(), 716 classification_executor_.get()); 717 std::vector<Token> tokens; 718 if (!ModelSuggestSelection(context_unicode, click_indices, 719 detected_text_language_tags, &interpreter_manager, 720 &tokens, &candidates)) { 721 TC3_LOG(ERROR) << "Model suggest selection failed."; 722 return original_click_indices; 723 } 724 if (!RegexChunk(context_unicode, selection_regex_patterns_, &candidates, 725 /*is_serialized_entity_data_enabled=*/false)) { 726 TC3_LOG(ERROR) << "Regex suggest selection failed."; 727 return original_click_indices; 728 } 729 if (!DatetimeChunk( 730 UTF8ToUnicodeText(context, /*do_copy=*/false), 731 /*reference_time_ms_utc=*/0, /*reference_timezone=*/"", 732 options.locales, ModeFlag_SELECTION, options.annotation_usecase, 733 /*is_serialized_entity_data_enabled=*/false, &candidates)) { 734 TC3_LOG(ERROR) << "Datetime suggest selection failed."; 735 return original_click_indices; 736 } 737 if (knowledge_engine_ != nullptr && 738 !knowledge_engine_->Chunk(context, &candidates)) { 739 TC3_LOG(ERROR) << "Knowledge suggest selection failed."; 740 return original_click_indices; 741 } 742 if (contact_engine_ != nullptr && 743 !contact_engine_->Chunk(context_unicode, tokens, &candidates)) { 744 TC3_LOG(ERROR) << "Contact suggest selection failed."; 745 return original_click_indices; 746 } 747 if (installed_app_engine_ != nullptr && 748 !installed_app_engine_->Chunk(context_unicode, tokens, &candidates)) { 749 TC3_LOG(ERROR) << "Installed app suggest selection failed."; 750 return original_click_indices; 751 } 752 if (number_annotator_ != nullptr && 753 !number_annotator_->FindAll(context_unicode, options.annotation_usecase, 754 &candidates)) { 755 TC3_LOG(ERROR) << "Number annotator failed in suggest selection."; 756 return 
original_click_indices; 757 } 758 if (duration_annotator_ != nullptr && 759 !duration_annotator_->FindAll(context_unicode, tokens, 760 options.annotation_usecase, &candidates)) { 761 TC3_LOG(ERROR) << "Duration annotator failed in suggest selection."; 762 return original_click_indices; 763 } 764 765 // Sort candidates according to their position in the input, so that the next 766 // code can assume that any connected component of overlapping spans forms a 767 // contiguous block. 768 std::sort(candidates.begin(), candidates.end(), 769 [](const AnnotatedSpan& a, const AnnotatedSpan& b) { 770 return a.span.first < b.span.first; 771 }); 772 773 std::vector<int> candidate_indices; 774 if (!ResolveConflicts(candidates, context, tokens, 775 detected_text_language_tags, options.annotation_usecase, 776 &interpreter_manager, &candidate_indices)) { 777 TC3_LOG(ERROR) << "Couldn't resolve conflicts."; 778 return original_click_indices; 779 } 780 781 std::sort(candidate_indices.begin(), candidate_indices.end(), 782 [&candidates](int a, int b) { 783 return GetPriorityScore(candidates[a].classification) > 784 GetPriorityScore(candidates[b].classification); 785 }); 786 787 for (const int i : candidate_indices) { 788 if (SpansOverlap(candidates[i].span, click_indices) && 789 SpansOverlap(candidates[i].span, original_click_indices)) { 790 // Run model classification if not present but requested and there's a 791 // classification collection filter specified. 792 if (candidates[i].classification.empty() && 793 model_->selection_options()->always_classify_suggested_selection() && 794 !filtered_collections_selection_.empty()) { 795 if (!ModelClassifyText(context, detected_text_language_tags, 796 candidates[i].span, &interpreter_manager, 797 /*embedding_cache=*/nullptr, 798 &candidates[i].classification)) { 799 return original_click_indices; 800 } 801 } 802 803 // Ignore if span classification is filtered. 
804 if (FilteredForSelection(candidates[i])) { 805 return original_click_indices; 806 } 807 808 return candidates[i].span; 809 } 810 } 811 812 return original_click_indices; 813 } 814 815 namespace { 816 // Helper function that returns the index of the first candidate that 817 // transitively does not overlap with the candidate on 'start_index'. If the end 818 // of 'candidates' is reached, it returns the index that points right behind the 819 // array. 820 int FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan>& candidates, 821 int start_index) { 822 int first_non_overlapping = start_index + 1; 823 CodepointSpan conflicting_span = candidates[start_index].span; 824 while ( 825 first_non_overlapping < candidates.size() && 826 SpansOverlap(conflicting_span, candidates[first_non_overlapping].span)) { 827 // Grow the span to include the current one. 828 conflicting_span.second = std::max( 829 conflicting_span.second, candidates[first_non_overlapping].span.second); 830 831 ++first_non_overlapping; 832 } 833 return first_non_overlapping; 834 } 835 } // namespace 836 837 bool Annotator::ResolveConflicts( 838 const std::vector<AnnotatedSpan>& candidates, const std::string& context, 839 const std::vector<Token>& cached_tokens, 840 const std::vector<Locale>& detected_text_language_tags, 841 AnnotationUsecase annotation_usecase, 842 InterpreterManager* interpreter_manager, std::vector<int>* result) const { 843 result->clear(); 844 result->reserve(candidates.size()); 845 for (int i = 0; i < candidates.size();) { 846 int first_non_overlapping = 847 FirstNonOverlappingSpanIndex(candidates, /*start_index=*/i); 848 849 const bool conflict_found = first_non_overlapping != (i + 1); 850 if (conflict_found) { 851 std::vector<int> candidate_indices; 852 if (!ResolveConflict(context, cached_tokens, candidates, 853 detected_text_language_tags, i, 854 first_non_overlapping, annotation_usecase, 855 interpreter_manager, &candidate_indices)) { 856 return false; 857 } 858 
result->insert(result->end(), candidate_indices.begin(), 859 candidate_indices.end()); 860 } else { 861 result->push_back(i); 862 } 863 864 // Skip over the whole conflicting group/go to next candidate. 865 i = first_non_overlapping; 866 } 867 return true; 868 } 869 870 namespace { 871 // Returns true, if the given two sources do conflict in given annotation 872 // usecase. 873 // - In SMART usecase, all sources do conflict, because there's only 1 possible 874 // annotation for a given span. 875 // - In RAW usecase, certain annotations are allowed to overlap (e.g. datetime 876 // and duration), while others not (e.g. duration and number). 877 bool DoSourcesConflict(AnnotationUsecase annotation_usecase, 878 const AnnotatedSpan::Source source1, 879 const AnnotatedSpan::Source source2) { 880 uint32 source_mask = 881 (1 << static_cast<int>(source1)) | (1 << static_cast<int>(source2)); 882 883 switch (annotation_usecase) { 884 case AnnotationUsecase_ANNOTATION_USECASE_SMART: 885 // In the SMART mode, all annotations conflict. 886 return true; 887 888 case AnnotationUsecase_ANNOTATION_USECASE_RAW: 889 // DURATION and DATETIME do not conflict. E.g. "let's meet in 3 hours", 890 // can have two non-conflicting annotations: "in 3 hours" (datetime), "3 891 // hours" (duration). 892 if ((source_mask & 893 (1 << static_cast<int>(AnnotatedSpan::Source::DURATION))) && 894 (source_mask & 895 (1 << static_cast<int>(AnnotatedSpan::Source::DATETIME)))) { 896 return false; 897 } 898 899 // A KNOWLEDGE entity does not conflict with anything. 900 if ((source_mask & 901 (1 << static_cast<int>(AnnotatedSpan::Source::KNOWLEDGE)))) { 902 return false; 903 } 904 905 // Entities from other sources can conflict. 
906 return true; 907 } 908 } 909 } // namespace 910 911 bool Annotator::ResolveConflict( 912 const std::string& context, const std::vector<Token>& cached_tokens, 913 const std::vector<AnnotatedSpan>& candidates, 914 const std::vector<Locale>& detected_text_language_tags, int start_index, 915 int end_index, AnnotationUsecase annotation_usecase, 916 InterpreterManager* interpreter_manager, 917 std::vector<int>* chosen_indices) const { 918 std::vector<int> conflicting_indices; 919 std::unordered_map<int, float> scores; 920 for (int i = start_index; i < end_index; ++i) { 921 conflicting_indices.push_back(i); 922 if (!candidates[i].classification.empty()) { 923 scores[i] = GetPriorityScore(candidates[i].classification); 924 continue; 925 } 926 927 // OPTIMIZATION: So that we don't have to classify all the ML model 928 // spans apriori, we wait until we get here, when they conflict with 929 // something and we need the actual classification scores. So if the 930 // candidate conflicts and comes from the model, we need to run a 931 // classification to determine its priority: 932 std::vector<ClassificationResult> classification; 933 if (!ModelClassifyText(context, cached_tokens, detected_text_language_tags, 934 candidates[i].span, interpreter_manager, 935 /*embedding_cache=*/nullptr, &classification)) { 936 return false; 937 } 938 939 if (!classification.empty()) { 940 scores[i] = GetPriorityScore(classification); 941 } 942 } 943 944 std::sort(conflicting_indices.begin(), conflicting_indices.end(), 945 [&scores](int i, int j) { return scores[i] > scores[j]; }); 946 947 // Here we keep a set of indices that were chosen, per-source, to enable 948 // effective computation. 949 std::unordered_map<AnnotatedSpan::Source, SortedIntSet> 950 chosen_indices_for_source_map; 951 952 // Greedily place the candidates if they don't conflict with the already 953 // placed ones. 
  for (int i = 0; i < conflicting_indices.size(); ++i) {
    const int considered_candidate = conflicting_indices[i];

    // See if there is a conflict between the candidate and all already placed
    // candidates.
    bool conflict = false;
    SortedIntSet* chosen_indices_for_source_ptr = nullptr;
    for (auto& source_set_pair : chosen_indices_for_source_map) {
      // Remember the set for this candidate's own source so we don't have to
      // look it up again below.
      if (source_set_pair.first == candidates[considered_candidate].source) {
        chosen_indices_for_source_ptr = &source_set_pair.second;
      }

      // Only sources that can conflict in this usecase need a span check.
      if (DoSourcesConflict(annotation_usecase, source_set_pair.first,
                            candidates[considered_candidate].source) &&
          DoesCandidateConflict(considered_candidate, candidates,
                                source_set_pair.second)) {
        conflict = true;
        break;
      }
    }

    // Skip the candidate if a conflict was found.
    if (conflict) {
      continue;
    }

    // If the set of indices for the current source doesn't exist yet,
    // initialize it.
    if (chosen_indices_for_source_ptr == nullptr) {
      // The set is ordered by span start so overlap queries can be answered
      // efficiently.
      SortedIntSet new_set([&candidates](int a, int b) {
        return candidates[a].span.first < candidates[b].span.first;
      });
      chosen_indices_for_source_map[candidates[considered_candidate].source] =
          std::move(new_set);
      chosen_indices_for_source_ptr =
          &chosen_indices_for_source_map[candidates[considered_candidate]
                                             .source];
    }

    // Place the candidate to the output and to the per-source conflict set.
    chosen_indices->push_back(considered_candidate);
    chosen_indices_for_source_ptr->insert(considered_candidate);
  }

  // Candidates were visited in priority order; restore index order for the
  // caller.
  std::sort(chosen_indices->begin(), chosen_indices->end());

  return true;
}

// Runs the ML selection model around 'click_indices' and appends the
// suggested selection spans to 'result'. Returns true (without producing
// candidates) when selection is disabled or the locale is unsupported;
// returns false on processing errors.
bool Annotator::ModelSuggestSelection(
    const UnicodeText& context_unicode, CodepointSpan click_indices,
    const std::vector<Locale>& detected_text_language_tags,
    InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
    std::vector<AnnotatedSpan>* result) const {
  // Selection mode must be enabled in the model's triggering options.
  if (model_->triggering_options() == nullptr ||
      !(model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)) {
    return true;
  }

  if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
                                    ml_model_triggering_locales_,
                                    /*default_value=*/true)) {
    return true;
  }

  int click_pos;
  *tokens = selection_feature_processor_->Tokenize(context_unicode);
  selection_feature_processor_->RetokenizeAndFindClick(
      context_unicode, click_indices,
      selection_feature_processor_->GetOptions()->only_use_line_with_click(),
      tokens, &click_pos);
  if (click_pos == kInvalidIndex) {
    TC3_VLOG(1) << "Could not calculate the click position.";
    return false;
  }

  const int symmetry_context_size =
      model_->selection_options()->symmetry_context_size();
  const FeatureProcessorOptions_::BoundsSensitiveFeatures*
      bounds_sensitive_features = selection_feature_processor_->GetOptions()
                                      ->bounds_sensitive_features();

  // The symmetry context span is the clicked token with symmetry_context_size
  // tokens on either side.
  const TokenSpan symmetry_context_span = IntersectTokenSpans(
      ExpandTokenSpan(SingleTokenSpan(click_pos),
                      /*num_tokens_left=*/symmetry_context_size,
                      /*num_tokens_right=*/symmetry_context_size),
      {0, tokens->size()});

  // Compute the extraction span based on the model type.
  TokenSpan extraction_span;
  if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
    // The extraction span is the symmetry context span expanded to include
    // max_selection_span tokens on either side, which is how far a selection
    // can stretch from the click, plus a relevant number of tokens outside of
    // the bounds of the selection.
    const int max_selection_span =
        selection_feature_processor_->GetOptions()->max_selection_span();
    extraction_span =
        ExpandTokenSpan(symmetry_context_span,
                        /*num_tokens_left=*/max_selection_span +
                            bounds_sensitive_features->num_tokens_before(),
                        /*num_tokens_right=*/max_selection_span +
                            bounds_sensitive_features->num_tokens_after());
  } else {
    // The extraction span is the symmetry context span expanded to include
    // context_size tokens on either side.
    const int context_size =
        selection_feature_processor_->GetOptions()->context_size();
    extraction_span = ExpandTokenSpan(symmetry_context_span,
                                      /*num_tokens_left=*/context_size,
                                      /*num_tokens_right=*/context_size);
  }
  // Clamp to the tokenized text.
  extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});

  // Bail out silently (no candidates) if too few codepoints are supported.
  if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
          *tokens, extraction_span)) {
    return true;
  }

  std::unique_ptr<CachedFeatures> cached_features;
  if (!selection_feature_processor_->ExtractFeatures(
          *tokens, extraction_span,
          /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
          embedding_executor_.get(),
          /*embedding_cache=*/nullptr,
          selection_feature_processor_->EmbeddingSize() +
              selection_feature_processor_->DenseFeaturesCount(),
          &cached_features)) {
    TC3_LOG(ERROR) << "Could not extract features.";
    return false;
  }

  // Produce selection model candidates.
  std::vector<TokenSpan> chunks;
  if (!ModelChunk(tokens->size(), /*span_of_interest=*/symmetry_context_span,
                  interpreter_manager->SelectionInterpreter(), *cached_features,
                  &chunks)) {
    TC3_LOG(ERROR) << "Could not chunk.";
    return false;
  }

  // Convert each token chunk into a codepoint-span candidate, stripping
  // boundary codepoints (and optionally unpaired brackets).
  for (const TokenSpan& chunk : chunks) {
    AnnotatedSpan candidate;
    candidate.span = selection_feature_processor_->StripBoundaryCodepoints(
        context_unicode, TokenSpanToCodepointSpan(*tokens, chunk));
    if (model_->selection_options()->strip_unpaired_brackets()) {
      candidate.span =
          StripUnpairedBrackets(context_unicode, candidate.span, *unilib_);
    }

    // Only output non-empty spans.
    if (candidate.span.first != candidate.span.second) {
      result->push_back(candidate);
    }
  }
  return true;
}

// Convenience overload: classification without pre-tokenized cached tokens.
bool Annotator::ModelClassifyText(
    const std::string& context,
    const std::vector<Locale>& detected_text_language_tags,
    CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
    FeatureProcessor::EmbeddingCache* embedding_cache,
    std::vector<ClassificationResult>* classification_results) const {
  return ModelClassifyText(context, {}, detected_text_language_tags,
                           selection_indices, interpreter_manager,
                           embedding_cache, classification_results);
}

namespace internal {
// Copies the tokens overlapping 'selection_indices' plus
// 'tokens_around_selection_to_copy' extra tokens on each side out of
// 'cached_tokens'. Assumes 'cached_tokens' is sorted by codepoint position.
std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
                                    CodepointSpan selection_indices,
                                    TokenSpan tokens_around_selection_to_copy) {
  // First token whose end lies beyond the selection start.
  const auto first_selection_token = std::upper_bound(
      cached_tokens.begin(), cached_tokens.end(), selection_indices.first,
      [](int selection_start, const Token& token) {
        return selection_start < token.end;
      });
  // First token starting at or after the selection end.
  const auto last_selection_token = std::lower_bound(
      cached_tokens.begin(), cached_tokens.end(), selection_indices.second,
      [](const Token& token, int selection_end) {
        return token.start <
               selection_end;
      });

  // Clamp the copy window to the available tokens; int64 avoids signed
  // overflow of the iterator arithmetic.
  const int64 first_token = std::max(
      static_cast<int64>(0),
      static_cast<int64>((first_selection_token - cached_tokens.begin()) -
                         tokens_around_selection_to_copy.first));
  const int64 last_token = std::min(
      static_cast<int64>(cached_tokens.size()),
      static_cast<int64>((last_selection_token - cached_tokens.begin()) +
                         tokens_around_selection_to_copy.second));

  std::vector<Token> tokens;
  tokens.reserve(last_token - first_token);
  for (int i = first_token; i < last_token; ++i) {
    tokens.push_back(cached_tokens[i]);
  }
  return tokens;
}
}  // namespace internal

// Returns the maximum number of tokens (left, right) around the selection
// that classification can need, used to size the cached-token copy.
TokenSpan Annotator::ClassifyTextUpperBoundNeededTokens() const {
  const FeatureProcessorOptions_::BoundsSensitiveFeatures*
      bounds_sensitive_features =
          classification_feature_processor_->GetOptions()
              ->bounds_sensitive_features();
  if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
    // The extraction span is the selection span expanded to include a relevant
    // number of tokens outside of the bounds of the selection.
    return {bounds_sensitive_features->num_tokens_before(),
            bounds_sensitive_features->num_tokens_after()};
  } else {
    // The extraction span is the clicked token with context_size tokens on
    // either side.
    // NOTE(review): this reads the *selection* feature processor's
    // context_size although the bound is used for classification — verify
    // this shouldn't be classification_feature_processor_.
    const int context_size =
        selection_feature_processor_->GetOptions()->context_size();
    return {context_size, context_size};
  }
}

namespace {
// Sorts the classification results from high score to low score.
void SortClassificationResults(
    std::vector<ClassificationResult>* classification_results) {
  std::sort(classification_results->begin(), classification_results->end(),
            [](const ClassificationResult& a, const ClassificationResult& b) {
              return a.score > b.score;
            });
}
}  // namespace

// Overload that runs classification but discards the tokenization.
bool Annotator::ModelClassifyText(
    const std::string& context, const std::vector<Token>& cached_tokens,
    const std::vector<Locale>& detected_text_language_tags,
    CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
    FeatureProcessor::EmbeddingCache* embedding_cache,
    std::vector<ClassificationResult>* classification_results) const {
  std::vector<Token> tokens;
  return ModelClassifyText(context, cached_tokens, detected_text_language_tags,
                           selection_indices, interpreter_manager,
                           embedding_cache, classification_results, &tokens);
}

// Runs the ML classification model on 'selection_indices' of 'context'.
// Returns true with empty results when classification is disabled or the
// locale is unsupported; returns false on processing errors. 'tokens'
// receives the tokenization used.
bool Annotator::ModelClassifyText(
    const std::string& context, const std::vector<Token>& cached_tokens,
    const std::vector<Locale>& detected_text_language_tags,
    CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
    FeatureProcessor::EmbeddingCache* embedding_cache,
    std::vector<ClassificationResult>* classification_results,
    std::vector<Token>* tokens) const {
  // Classification mode must be enabled in the model's triggering options.
  if (model_->triggering_options() == nullptr ||
      !(model_->triggering_options()->enabled_modes() &
        ModeFlag_CLASSIFICATION)) {
    return true;
  }

  if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
                                    ml_model_triggering_locales_,
                                    /*default_value=*/true)) {
    return true;
  }

  // Reuse the caller's tokenization when available to avoid re-tokenizing.
  if (cached_tokens.empty()) {
    *tokens = classification_feature_processor_->Tokenize(context);
  } else {
    *tokens = internal::CopyCachedTokens(cached_tokens, selection_indices,
                                         ClassifyTextUpperBoundNeededTokens());
  }

  int click_pos;
  classification_feature_processor_->RetokenizeAndFindClick(
      context, selection_indices,
      classification_feature_processor_->GetOptions()
          ->only_use_line_with_click(),
      tokens, &click_pos);
  const TokenSpan selection_token_span =
      CodepointSpanToTokenSpan(*tokens, selection_indices);
  const int selection_num_tokens = TokenSpanSize(selection_token_span);
  // Selections that exceed the configured token budget classify as "other".
  if (model_->classification_options()->max_num_tokens() > 0 &&
      model_->classification_options()->max_num_tokens() <
          selection_num_tokens) {
    *classification_results = {{Collections::Other(), 1.0}};
    return true;
  }

  const FeatureProcessorOptions_::BoundsSensitiveFeatures*
      bounds_sensitive_features =
          classification_feature_processor_->GetOptions()
              ->bounds_sensitive_features();
  if (selection_token_span.first == kInvalidIndex ||
      selection_token_span.second == kInvalidIndex) {
    TC3_LOG(ERROR) << "Could not determine span.";
    return false;
  }

  // Compute the extraction span based on the model type.
  TokenSpan extraction_span;
  if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
    // The extraction span is the selection span expanded to include a relevant
    // number of tokens outside of the bounds of the selection.
    extraction_span = ExpandTokenSpan(
        selection_token_span,
        /*num_tokens_left=*/bounds_sensitive_features->num_tokens_before(),
        /*num_tokens_right=*/bounds_sensitive_features->num_tokens_after());
  } else {
    // The click-context model requires a valid click position.
    if (click_pos == kInvalidIndex) {
      TC3_LOG(ERROR) << "Couldn't choose a click position.";
      return false;
    }
    // The extraction span is the clicked token with context_size tokens on
    // either side.
    const int context_size =
        classification_feature_processor_->GetOptions()->context_size();
    extraction_span = ExpandTokenSpan(SingleTokenSpan(click_pos),
                                      /*num_tokens_left=*/context_size,
                                      /*num_tokens_right=*/context_size);
  }
  // Clamp to the tokenized text.
  extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});

  // Too few supported codepoints: classify as "other" rather than erroring.
  if (!classification_feature_processor_->HasEnoughSupportedCodepoints(
          *tokens, extraction_span)) {
    *classification_results = {{Collections::Other(), 1.0}};
    return true;
  }

  std::unique_ptr<CachedFeatures> cached_features;
  if (!classification_feature_processor_->ExtractFeatures(
          *tokens, extraction_span, selection_indices,
          embedding_executor_.get(), embedding_cache,
          classification_feature_processor_->EmbeddingSize() +
              classification_feature_processor_->DenseFeaturesCount(),
          &cached_features)) {
    TC3_LOG(ERROR) << "Could not extract features.";
    return false;
  }

  // Assemble the feature vector matching the model variant in use.
  std::vector<float> features;
  features.reserve(cached_features->OutputFeaturesSize());
  if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
    cached_features->AppendBoundsSensitiveFeaturesForSpan(selection_token_span,
                                                          &features);
  } else {
    cached_features->AppendClickContextFeaturesForClick(click_pos, &features);
  }

  TensorView<float> logits = classification_executor_->ComputeLogits(
      TensorView<float>(features.data(),
                        {1, static_cast<int>(features.size())}),
      interpreter_manager->ClassificationInterpreter());
  if (!logits.is_valid()) {
    TC3_LOG(ERROR) << "Couldn't compute logits.";
    return false;
  }

  // Expect a [1 x num_collections] logit matrix.
  if (logits.dims() != 2 || logits.dim(0) != 1 ||
      logits.dim(1) != classification_feature_processor_->NumCollections()) {
    TC3_LOG(ERROR) << "Mismatching output";
    return false;
  }

  const std::vector<float> scores =
      ComputeSoftmax(logits.data(), logits.dim(1));

  if (scores.empty()) {
    *classification_results = {{Collections::Other(), 1.0}};
    return true;
  }

  const int best_score_index =
      std::max_element(scores.begin(), scores.end()) - scores.begin();
  const std::string top_collection =
      classification_feature_processor_->LabelToCollection(best_score_index);

  // Sanity checks.
  // Phone: require a plausible digit count for the selection.
  if (top_collection == Collections::Phone()) {
    const int digit_count = CountDigits(context, selection_indices);
    if (digit_count <
            model_->classification_options()->phone_min_num_digits() ||
        digit_count >
            model_->classification_options()->phone_max_num_digits()) {
      *classification_results = {{Collections::Other(), 1.0}};
      return true;
    }
  } else if (top_collection == Collections::Address()) {
    // Address: require a minimum number of tokens.
    if (selection_num_tokens <
        model_->classification_options()->address_min_num_tokens()) {
      *classification_results = {{Collections::Other(), 1.0}};
      return true;
    }
  } else if (top_collection == Collections::Dictionary()) {
    // Dictionary: only trigger for explicitly supported locales.
    if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
                                      dictionary_locales_,
                                      /*default_value=*/false)) {
      *classification_results = {{Collections::Other(), 1.0}};
      return true;
    }
  }

  // Third field is presumably the priority score (raw softmax probability) —
  // matches the ClassificationResult initialization used elsewhere here.
  *classification_results = {{top_collection, 1.0, scores[best_score_index]}};
  return true;
}

// Classifies the selection with the regex models; appends one result per
// matching pattern. Returns false on matcher or entity-data errors.
bool Annotator::RegexClassifyText(
    const std::string& context, CodepointSpan selection_indices,
    std::vector<ClassificationResult>* classification_result) const {
  const std::string selection_text =
      UTF8ToUnicodeText(context, /*do_copy=*/false)
          .UTF8Substring(selection_indices.first, selection_indices.second);
  const UnicodeText selection_text_unicode(
      UTF8ToUnicodeText(selection_text, /*do_copy=*/false));

  // Check whether any of the regular expressions match.
  for (const int pattern_id : classification_regex_patterns_) {
    const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
    const std::unique_ptr<UniLib::RegexMatcher> matcher =
        regex_pattern.pattern->Matcher(selection_text_unicode);
    int status = UniLib::RegexMatcher::kNoError;
    bool matches;
    if (regex_pattern.config->use_approximate_matching()) {
      matches = matcher->ApproximatelyMatches(&status);
    } else {
      matches = matcher->Matches(&status);
    }
    // Matcher errors abort the whole classification.
    if (status != UniLib::RegexMatcher::kNoError) {
      return false;
    }
    if (matches && VerifyRegexMatchCandidate(
                       context, regex_pattern.config->verification_options(),
                       selection_text, matcher.get())) {
      classification_result->push_back(
          {regex_pattern.config->collection_name()->str(),
           regex_pattern.config->target_classification_score(),
           regex_pattern.config->priority_score()});
      if (!SerializedEntityDataFromRegexMatch(
              regex_pattern.config, matcher.get(),
              &classification_result->back().serialized_entity_data)) {
        TC3_LOG(ERROR) << "Could not get entity data.";
        return false;
      }
    }
  }

  return true;
}

namespace {
// Maps a datetime parse to the "datetime" collection for time-of-day
// granularities and to "date" otherwise.
std::string PickCollectionForDatetime(
    const DatetimeParseResult& datetime_parse_result) {
  switch (datetime_parse_result.granularity) {
    case GRANULARITY_HOUR:
    case GRANULARITY_MINUTE:
    case GRANULARITY_SECOND:
      return Collections::DateTime();
    default:
      return Collections::Date();
  }
}

// Serializes a datetime parse result into an EntityData flatbuffer string.
std::string CreateDatetimeSerializedEntityData(
    const DatetimeParseResult& parse_result) {
  EntityDataT entity_data;
  entity_data.datetime.reset(new EntityData_::DatetimeT());
  entity_data.datetime->time_ms_utc = parse_result.time_ms_utc;
  entity_data.datetime->granularity =
      static_cast<EntityData_::Datetime_::Granularity>(
          parse_result.granularity);

  flatbuffers::FlatBufferBuilder builder;
  FinishEntityDataBuffer(builder, EntityData::Pack(builder, &entity_data));
  return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
                     builder.GetSize());
}
}  // namespace

// Classifies the selection with the datetime parser. Appends results only
// when the parsed span exactly matches the selection. Returns false when the
// parser is missing or parsing fails.
bool Annotator::DatetimeClassifyText(
    const std::string& context, CodepointSpan selection_indices,
    const ClassificationOptions& options,
    std::vector<ClassificationResult>* classification_results) const {
  if (!datetime_parser_) {
    return false;
  }

  const std::string selection_text =
      UTF8ToUnicodeText(context, /*do_copy=*/false)
          .UTF8Substring(selection_indices.first, selection_indices.second);

  std::vector<DatetimeParseResultSpan> datetime_spans;
  // anchor_start_end=true: the match must cover the whole selection text.
  if (!datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
                               options.reference_timezone, options.locales,
                               ModeFlag_CLASSIFICATION,
                               options.annotation_usecase,
                               /*anchor_start_end=*/true, &datetime_spans)) {
    TC3_LOG(ERROR) << "Error during parsing datetime.";
    return false;
  }
  for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
    // Only consider the result valid if the selection and extracted datetime
    // spans exactly match.
    // Datetime spans are relative to the selection text, so shift by the
    // selection start before comparing.
    if (std::make_pair(datetime_span.span.first + selection_indices.first,
                       datetime_span.span.second + selection_indices.first) ==
        selection_indices) {
      // One ClassificationResult per alternative interpretation.
      for (const DatetimeParseResult& parse_result : datetime_span.data) {
        classification_results->emplace_back(
            PickCollectionForDatetime(parse_result),
            datetime_span.target_classification_score);
        classification_results->back().datetime_parse_result = parse_result;
        classification_results->back().serialized_entity_data =
            CreateDatetimeSerializedEntityData(parse_result);
        classification_results->back().priority_score =
            datetime_span.priority_score;
      }
      return true;
    }
  }
  return true;
}

// Public entry point: classifies the given selection by running all enabled
// engines, resolving conflicts, and returning results sorted by score.
// Returns an empty vector when classification cannot run; a single "other"
// result when nothing triggers.
std::vector<ClassificationResult> Annotator::ClassifyText(
    const std::string& context, CodepointSpan selection_indices,
    const ClassificationOptions& options) const {
  if (!initialized_) {
    TC3_LOG(ERROR) << "Not initialized";
    return {};
  }

  if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
    return {};
  }

  std::vector<Locale> detected_text_language_tags;
  if (!ParseLocales(options.detected_text_language_tags,
                    &detected_text_language_tags)) {
    TC3_LOG(WARNING)
        << "Failed to parse the detected_text_language_tags in options: "
        << options.detected_text_language_tags;
  }
  if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
                                    model_triggering_locales_,
                                    /*default_value=*/true)) {
    return {};
  }

  // Reject invalid UTF-8 input early.
  if (!UTF8ToUnicodeText(context, /*do_copy=*/false).is_valid()) {
    return {};
  }

  if (std::get<0>(selection_indices) >= std::get<1>(selection_indices)) {
    TC3_VLOG(1) << "Trying to run ClassifyText with invalid indices: "
                << std::get<0>(selection_indices) << " "
                << std::get<1>(selection_indices);
    return {};
  }

  // We'll accumulate a list of candidates, and pick the best candidate in the
  //
end. 1512 std::vector<AnnotatedSpan> candidates; 1513 1514 // Try the knowledge engine. 1515 // TODO(b/126579108): Propagate error status. 1516 ClassificationResult knowledge_result; 1517 if (knowledge_engine_ && knowledge_engine_->ClassifyText( 1518 context, selection_indices, &knowledge_result)) { 1519 candidates.push_back({selection_indices, {knowledge_result}}); 1520 candidates.back().source = AnnotatedSpan::Source::KNOWLEDGE; 1521 } 1522 1523 // Try the contact engine. 1524 // TODO(b/126579108): Propagate error status. 1525 ClassificationResult contact_result; 1526 if (contact_engine_ && contact_engine_->ClassifyText( 1527 context, selection_indices, &contact_result)) { 1528 candidates.push_back({selection_indices, {contact_result}}); 1529 } 1530 1531 // Try the installed app engine. 1532 // TODO(b/126579108): Propagate error status. 1533 ClassificationResult installed_app_result; 1534 if (installed_app_engine_ && 1535 installed_app_engine_->ClassifyText(context, selection_indices, 1536 &installed_app_result)) { 1537 candidates.push_back({selection_indices, {installed_app_result}}); 1538 } 1539 1540 // Try the regular expression models. 1541 std::vector<ClassificationResult> regex_results; 1542 if (!RegexClassifyText(context, selection_indices, ®ex_results)) { 1543 return {}; 1544 } 1545 for (const ClassificationResult& result : regex_results) { 1546 candidates.push_back({selection_indices, {result}}); 1547 } 1548 1549 // Try the date model. 1550 // 1551 // DatetimeClassifyText only returns the first result, which can however have 1552 // more interpretations. They are inserted in the candidates as a single 1553 // AnnotatedSpan, so that they get treated together by the conflict resolution 1554 // algorithm. 
  std::vector<ClassificationResult> datetime_results;
  if (!DatetimeClassifyText(context, selection_indices, options,
                            &datetime_results)) {
    return {};
  }
  if (!datetime_results.empty()) {
    candidates.push_back({selection_indices, std::move(datetime_results)});
    candidates.back().source = AnnotatedSpan::Source::DATETIME;
  }

  // Try the number annotator.
  // TODO(b/126579108): Propagate error status.
  ClassificationResult number_annotator_result;
  if (number_annotator_ &&
      number_annotator_->ClassifyText(
          UTF8ToUnicodeText(context, /*do_copy=*/false), selection_indices,
          options.annotation_usecase, &number_annotator_result)) {
    candidates.push_back({selection_indices, {number_annotator_result}});
  }

  // Try the duration annotator.
  ClassificationResult duration_annotator_result;
  if (duration_annotator_ &&
      duration_annotator_->ClassifyText(
          UTF8ToUnicodeText(context, /*do_copy=*/false), selection_indices,
          options.annotation_usecase, &duration_annotator_result)) {
    candidates.push_back({selection_indices, {duration_annotator_result}});
    // Tag the source so conflict resolution can allow datetime/duration
    // overlap in the RAW usecase.
    candidates.back().source = AnnotatedSpan::Source::DURATION;
  }

  // Try the ML model.
  //
  // The output of the model is considered as an exclusive 1-of-N choice. That's
  // why it's inserted as only 1 AnnotatedSpan into candidates, as opposed to 1
  // span for each candidate, like e.g. the regex model.
  InterpreterManager interpreter_manager(selection_executor_.get(),
                                         classification_executor_.get());
  std::vector<ClassificationResult> model_results;
  std::vector<Token> tokens;
  if (!ModelClassifyText(
          context, /*cached_tokens=*/{}, detected_text_language_tags,
          selection_indices, &interpreter_manager,
          /*embedding_cache=*/nullptr, &model_results, &tokens)) {
    return {};
  }
  if (!model_results.empty()) {
    candidates.push_back({selection_indices, std::move(model_results)});
  }

  // All candidates share the same span here, so conflict resolution picks
  // the winners by priority.
  std::vector<int> candidate_indices;
  if (!ResolveConflicts(candidates, context, tokens,
                        detected_text_language_tags, options.annotation_usecase,
                        &interpreter_manager, &candidate_indices)) {
    TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
    return {};
  }

  // Flatten the surviving candidates, dropping filtered collections.
  std::vector<ClassificationResult> results;
  for (const int i : candidate_indices) {
    for (const ClassificationResult& result : candidates[i].classification) {
      if (!FilteredForClassification(result)) {
        results.push_back(result);
      }
    }
  }

  // Sort results according to score.
1622 std::sort(results.begin(), results.end(), 1623 [](const ClassificationResult& a, const ClassificationResult& b) { 1624 return a.score > b.score; 1625 }); 1626 1627 if (results.empty()) { 1628 results = {{Collections::Other(), 1.0}}; 1629 } 1630 return results; 1631 } 1632 1633 bool Annotator::ModelAnnotate( 1634 const std::string& context, 1635 const std::vector<Locale>& detected_text_language_tags, 1636 InterpreterManager* interpreter_manager, std::vector<Token>* tokens, 1637 std::vector<AnnotatedSpan>* result) const { 1638 if (model_->triggering_options() == nullptr || 1639 !(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) { 1640 return true; 1641 } 1642 1643 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags, 1644 ml_model_triggering_locales_, 1645 /*default_value=*/true)) { 1646 return true; 1647 } 1648 1649 const UnicodeText context_unicode = UTF8ToUnicodeText(context, 1650 /*do_copy=*/false); 1651 std::vector<UnicodeTextRange> lines; 1652 if (!selection_feature_processor_->GetOptions()->only_use_line_with_click()) { 1653 lines.push_back({context_unicode.begin(), context_unicode.end()}); 1654 } else { 1655 lines = selection_feature_processor_->SplitContext(context_unicode); 1656 } 1657 1658 const float min_annotate_confidence = 1659 (model_->triggering_options() != nullptr 1660 ? 
           model_->triggering_options()->min_annotate_confidence()
           : 0.f);

  for (const UnicodeTextRange& line : lines) {
    // Embeddings can be shared across all chunks of this line.
    FeatureProcessor::EmbeddingCache embedding_cache;
    const std::string line_str =
        UnicodeText::UTF8Substring(line.first, line.second);

    *tokens = selection_feature_processor_->Tokenize(line_str);
    selection_feature_processor_->RetokenizeAndFindClick(
        line_str, {0, std::distance(line.first, line.second)},
        selection_feature_processor_->GetOptions()->only_use_line_with_click(),
        tokens,
        /*click_pos=*/nullptr);
    const TokenSpan full_line_span = {0, tokens->size()};

    // TODO(zilka): Add support for greater granularity of this check.
    if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
            *tokens, full_line_span)) {
      continue;
    }

    std::unique_ptr<CachedFeatures> cached_features;
    if (!selection_feature_processor_->ExtractFeatures(
            *tokens, full_line_span,
            /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
            embedding_executor_.get(),
            /*embedding_cache=*/nullptr,
            selection_feature_processor_->EmbeddingSize() +
                selection_feature_processor_->DenseFeaturesCount(),
            &cached_features)) {
      TC3_LOG(ERROR) << "Could not extract features.";
      return false;
    }

    std::vector<TokenSpan> local_chunks;
    if (!ModelChunk(tokens->size(), /*span_of_interest=*/full_line_span,
                    interpreter_manager->SelectionInterpreter(),
                    *cached_features, &local_chunks)) {
      TC3_LOG(ERROR) << "Could not chunk.";
      return false;
    }

    // Chunk spans are line-relative; offset converts them back to the full
    // context.
    const int offset = std::distance(context_unicode.begin(), line.first);
    for (const TokenSpan& chunk : local_chunks) {
      const CodepointSpan codepoint_span =
          selection_feature_processor_->StripBoundaryCodepoints(
              line_str, TokenSpanToCodepointSpan(*tokens, chunk));

      // Skip empty spans.
      if (codepoint_span.first != codepoint_span.second) {
        std::vector<ClassificationResult> classification;
        if (!ModelClassifyText(line_str, *tokens, detected_text_language_tags,
                               codepoint_span, interpreter_manager,
                               &embedding_cache, &classification)) {
          TC3_LOG(ERROR) << "Could not classify text: "
                         << (codepoint_span.first + offset) << " "
                         << (codepoint_span.second + offset);
          return false;
        }

        // Do not include the span if it's classified as "other".
        if (!classification.empty() && !ClassifiedAsOther(classification) &&
            classification[0].score >= min_annotate_confidence) {
          AnnotatedSpan result_span;
          result_span.span = {codepoint_span.first + offset,
                              codepoint_span.second + offset};
          result_span.classification = std::move(classification);
          result->push_back(std::move(result_span));
        }
      }
    }
  }
  return true;
}

// Test-only accessors for the internal processors/parsers.
const FeatureProcessor* Annotator::SelectionFeatureProcessorForTests() const {
  return selection_feature_processor_.get();
}

const FeatureProcessor* Annotator::ClassificationFeatureProcessorForTests()
    const {
  return classification_feature_processor_.get();
}

const DatetimeParser* Annotator::DatetimeParserForTests() const {
  return datetime_parser_.get();
}

// Strips classification results whose entity type is not enabled, then drops
// spans that are left with no classification at all.
void Annotator::RemoveNotEnabledEntityTypes(
    const EnabledEntityTypes& is_entity_type_enabled,
    std::vector<AnnotatedSpan>* annotated_spans) const {
  for (AnnotatedSpan& annotated_span : *annotated_spans) {
    std::vector<ClassificationResult>& classifications =
        annotated_span.classification;
    classifications.erase(
        std::remove_if(classifications.begin(), classifications.end(),
                       [&is_entity_type_enabled](
                           const ClassificationResult& classification_result) {
                         return !is_entity_type_enabled(
                             classification_result.collection);
                       }),
        classifications.end());
  }
annotated_spans->erase( 1765 std::remove_if(annotated_spans->begin(), annotated_spans->end(), 1766 [](const AnnotatedSpan& annotated_span) { 1767 return annotated_span.classification.empty(); 1768 }), 1769 annotated_spans->end()); 1770 } 1771 1772 std::vector<AnnotatedSpan> Annotator::Annotate( 1773 const std::string& context, const AnnotationOptions& options) const { 1774 std::vector<AnnotatedSpan> candidates; 1775 1776 if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) { 1777 return {}; 1778 } 1779 1780 const UnicodeText context_unicode = 1781 UTF8ToUnicodeText(context, /*do_copy=*/false); 1782 if (!context_unicode.is_valid()) { 1783 return {}; 1784 } 1785 1786 std::vector<Locale> detected_text_language_tags; 1787 if (!ParseLocales(options.detected_text_language_tags, 1788 &detected_text_language_tags)) { 1789 TC3_LOG(WARNING) 1790 << "Failed to parse the detected_text_language_tags in options: " 1791 << options.detected_text_language_tags; 1792 } 1793 if (!Locale::IsAnyLocaleSupported(detected_text_language_tags, 1794 model_triggering_locales_, 1795 /*default_value=*/true)) { 1796 return {}; 1797 } 1798 1799 InterpreterManager interpreter_manager(selection_executor_.get(), 1800 classification_executor_.get()); 1801 1802 // Annotate with the selection model. 1803 std::vector<Token> tokens; 1804 if (!ModelAnnotate(context, detected_text_language_tags, &interpreter_manager, 1805 &tokens, &candidates)) { 1806 TC3_LOG(ERROR) << "Couldn't run ModelAnnotate."; 1807 return {}; 1808 } 1809 1810 // Annotate with the regular expression models. 1811 if (!RegexChunk(UTF8ToUnicodeText(context, /*do_copy=*/false), 1812 annotation_regex_patterns_, &candidates, 1813 options.is_serialized_entity_data_enabled)) { 1814 TC3_LOG(ERROR) << "Couldn't run RegexChunk."; 1815 return {}; 1816 } 1817 1818 // Annotate with the datetime model. 
1819 const EnabledEntityTypes is_entity_type_enabled(options.entity_types); 1820 if ((is_entity_type_enabled(Collections::Date()) || 1821 is_entity_type_enabled(Collections::DateTime())) && 1822 !DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false), 1823 options.reference_time_ms_utc, options.reference_timezone, 1824 options.locales, ModeFlag_ANNOTATION, 1825 options.annotation_usecase, 1826 options.is_serialized_entity_data_enabled, &candidates)) { 1827 TC3_LOG(ERROR) << "Couldn't run RegexChunk."; 1828 return {}; 1829 } 1830 1831 // Annotate with the knowledge engine. 1832 if (knowledge_engine_ && !knowledge_engine_->Chunk(context, &candidates)) { 1833 TC3_LOG(ERROR) << "Couldn't run knowledge engine Chunk."; 1834 return {}; 1835 } 1836 1837 // Annotate with the contact engine. 1838 if (contact_engine_ && 1839 !contact_engine_->Chunk(context_unicode, tokens, &candidates)) { 1840 TC3_LOG(ERROR) << "Couldn't run contact engine Chunk."; 1841 return {}; 1842 } 1843 1844 // Annotate with the installed app engine. 1845 if (installed_app_engine_ && 1846 !installed_app_engine_->Chunk(context_unicode, tokens, &candidates)) { 1847 TC3_LOG(ERROR) << "Couldn't run installed app engine Chunk."; 1848 return {}; 1849 } 1850 1851 // Annotate with the number annotator. 1852 if (number_annotator_ != nullptr && 1853 !number_annotator_->FindAll(context_unicode, options.annotation_usecase, 1854 &candidates)) { 1855 TC3_LOG(ERROR) << "Couldn't run number annotator FindAll."; 1856 return {}; 1857 } 1858 1859 // Annotate with the duration annotator. 
1860 if (is_entity_type_enabled(Collections::Duration()) && 1861 duration_annotator_ != nullptr && 1862 !duration_annotator_->FindAll(context_unicode, tokens, 1863 options.annotation_usecase, &candidates)) { 1864 TC3_LOG(ERROR) << "Couldn't run duration annotator FindAll."; 1865 return {}; 1866 } 1867 1868 // Sort candidates according to their position in the input, so that the next 1869 // code can assume that any connected component of overlapping spans forms a 1870 // contiguous block. 1871 std::sort(candidates.begin(), candidates.end(), 1872 [](const AnnotatedSpan& a, const AnnotatedSpan& b) { 1873 return a.span.first < b.span.first; 1874 }); 1875 1876 std::vector<int> candidate_indices; 1877 if (!ResolveConflicts(candidates, context, tokens, 1878 detected_text_language_tags, options.annotation_usecase, 1879 &interpreter_manager, &candidate_indices)) { 1880 TC3_LOG(ERROR) << "Couldn't resolve conflicts."; 1881 return {}; 1882 } 1883 1884 std::vector<AnnotatedSpan> result; 1885 result.reserve(candidate_indices.size()); 1886 AnnotatedSpan aggregated_span; 1887 for (const int i : candidate_indices) { 1888 if (candidates[i].span != aggregated_span.span) { 1889 if (!aggregated_span.classification.empty()) { 1890 result.push_back(std::move(aggregated_span)); 1891 } 1892 aggregated_span = 1893 AnnotatedSpan(candidates[i].span, /*arg_classification=*/{}); 1894 } 1895 if (candidates[i].classification.empty() || 1896 ClassifiedAsOther(candidates[i].classification) || 1897 FilteredForAnnotation(candidates[i])) { 1898 continue; 1899 } 1900 for (ClassificationResult& classification : candidates[i].classification) { 1901 aggregated_span.classification.push_back(std::move(classification)); 1902 } 1903 } 1904 if (!aggregated_span.classification.empty()) { 1905 result.push_back(std::move(aggregated_span)); 1906 } 1907 1908 // We generate all candidates and remove them later (with the exception of 1909 // date/time/duration entities) because there are complex interdependencies 
1910 // between the entity types. E.g., the TLD of an email can be interpreted as a 1911 // URL, but most likely a user of the API does not want such annotations if 1912 // "url" is enabled and "email" is not. 1913 RemoveNotEnabledEntityTypes(is_entity_type_enabled, &result); 1914 1915 for (AnnotatedSpan& annotated_span : result) { 1916 SortClassificationResults(&annotated_span.classification); 1917 } 1918 1919 return result; 1920 } 1921 1922 CodepointSpan Annotator::ComputeSelectionBoundaries( 1923 const UniLib::RegexMatcher* match, 1924 const RegexModel_::Pattern* config) const { 1925 if (config->capturing_group() == nullptr) { 1926 // Use first capturing group to specify the selection. 1927 int status = UniLib::RegexMatcher::kNoError; 1928 const CodepointSpan result = {match->Start(1, &status), 1929 match->End(1, &status)}; 1930 if (status != UniLib::RegexMatcher::kNoError) { 1931 return {kInvalidIndex, kInvalidIndex}; 1932 } 1933 return result; 1934 } 1935 1936 CodepointSpan result = {kInvalidIndex, kInvalidIndex}; 1937 const int num_groups = config->capturing_group()->size(); 1938 for (int i = 0; i < num_groups; i++) { 1939 if (!config->capturing_group()->Get(i)->extend_selection()) { 1940 continue; 1941 } 1942 1943 int status = UniLib::RegexMatcher::kNoError; 1944 // Check match and adjust bounds. 
    const int group_start = match->Start(i, &status);
    const int group_end = match->End(i, &status);
    if (status != UniLib::RegexMatcher::kNoError) {
      return {kInvalidIndex, kInvalidIndex};
    }
    // Groups that did not participate in the match report kInvalidIndex and
    // do not extend the selection.
    if (group_start == kInvalidIndex || group_end == kInvalidIndex) {
      continue;
    }
    // Grow the result to the union of all matched extend_selection groups.
    if (result.first == kInvalidIndex) {
      result = {group_start, group_end};
    } else {
      result.first = std::min(result.first, group_start);
      result.second = std::max(result.second, group_end);
    }
  }
  return result;
}

// Returns true if the pattern can produce entity data: either static
// serialized entity data, or at least one capturing group with an
// entity field path.
bool Annotator::HasEntityData(const RegexModel_::Pattern* pattern) const {
  if (pattern->serialized_entity_data() != nullptr) {
    return true;
  }
  if (pattern->capturing_group() != nullptr) {
    for (const RegexModel_::Pattern_::CapturingGroup* group :
         *pattern->capturing_group()) {
      if (group->entity_field_path() != nullptr) {
        return true;
      }
    }
  }
  return false;
}

// Builds the serialized entity data for a regex match: starts from the
// pattern's static serialized data (if any) and overlays values extracted
// from the match's capturing groups. Leaves `serialized_entity_data` empty
// when the pattern carries no entity data at all.
bool Annotator::SerializedEntityDataFromRegexMatch(
    const RegexModel_::Pattern* pattern, UniLib::RegexMatcher* matcher,
    std::string* serialized_entity_data) const {
  if (!HasEntityData(pattern)) {
    serialized_entity_data->clear();
    return true;
  }
  TC3_CHECK(entity_data_builder_ != nullptr);

  std::unique_ptr<ReflectiveFlatbuffer> entity_data =
      entity_data_builder_->NewRoot();

  TC3_CHECK(entity_data != nullptr);

  // Set static entity data.
  if (pattern->serialized_entity_data() != nullptr) {
    TC3_CHECK(entity_data != nullptr);
    entity_data->MergeFromSerializedFlatbuffer(
        StringPiece(pattern->serialized_entity_data()->c_str(),
                    pattern->serialized_entity_data()->size()));
  }

  // Add entity data from rule capturing groups.
  if (pattern->capturing_group() != nullptr) {
    const int num_groups = pattern->capturing_group()->size();
    for (int i = 0; i < num_groups; i++) {
      const FlatbufferFieldPath* field_path =
          pattern->capturing_group()->Get(i)->entity_field_path();
      if (field_path == nullptr) {
        continue;
      }
      TC3_CHECK(entity_data != nullptr);
      if (!SetFieldFromCapturingGroup(/*group_id=*/i, field_path, matcher,
                                      entity_data.get())) {
        TC3_LOG(ERROR)
            << "Could not set entity data from rule capturing group.";
        return false;
      }
    }
  }

  *serialized_entity_data = entity_data->Serialize();
  return true;
}

// Runs the regex patterns listed in `rules` over the whole context and
// appends one AnnotatedSpan per (verified) match to `result`.
bool Annotator::RegexChunk(const UnicodeText& context_unicode,
                           const std::vector<int>& rules,
                           std::vector<AnnotatedSpan>* result,
                           bool is_serialized_entity_data_enabled) const {
  for (int pattern_id : rules) {
    const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
    const auto matcher = regex_pattern.pattern->Matcher(context_unicode);
    if (!matcher) {
      TC3_LOG(ERROR) << "Could not get regex matcher for pattern: "
                     << pattern_id;
      return false;
    }

    int status = UniLib::RegexMatcher::kNoError;
    while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
      // Optional secondary verification (e.g. checksum-style checks) on the
      // matched text; failing candidates are skipped, not errors.
      if (regex_pattern.config->verification_options()) {
        if (!VerifyRegexMatchCandidate(
                context_unicode.ToUTF8String(),
                regex_pattern.config->verification_options(),
                matcher->Group(1, &status).ToUTF8String(), matcher.get())) {
          continue;
        }
      }

      std::string serialized_entity_data;
      if (is_serialized_entity_data_enabled) {
        if (!SerializedEntityDataFromRegexMatch(
                regex_pattern.config, matcher.get(), &serialized_entity_data)) {
          TC3_LOG(ERROR) << "Could not get entity data.";
          return false;
        }
      }

      result->emplace_back();

      // Selection/annotation regular expressions need to specify a capturing
      // group specifying the selection.
      result->back().span =
          ComputeSelectionBoundaries(matcher.get(), regex_pattern.config);

      result->back().classification = {
          {regex_pattern.config->collection_name()->str(),
           regex_pattern.config->target_classification_score(),
           regex_pattern.config->priority_score()}};

      // Empty when serialized entity data is disabled or the pattern has none.
      result->back().classification[0].serialized_entity_data =
          serialized_entity_data;
    }
  }
  return true;
}

// Produces non-overlapping token chunks inside the span of interest by
// scoring candidate spans with the selection model and greedily picking the
// best-scoring, mutually disjoint ones.
bool Annotator::ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
                           tflite::Interpreter* selection_interpreter,
                           const CachedFeatures& cached_features,
                           std::vector<TokenSpan>* chunks) const {
  const int max_selection_span =
      selection_feature_processor_->GetOptions()->max_selection_span();
  // The inference span is the span of interest expanded to include
  // max_selection_span tokens on either side, which is how far a selection can
  // stretch from the click.
  const TokenSpan inference_span = IntersectTokenSpans(
      ExpandTokenSpan(span_of_interest,
                      /*num_tokens_left=*/max_selection_span,
                      /*num_tokens_right=*/max_selection_span),
      {0, num_tokens});

  // Score candidate chunks with whichever scoring mode the model enables.
  std::vector<ScoredChunk> scored_chunks;
  if (selection_feature_processor_->GetOptions()->bounds_sensitive_features() &&
      selection_feature_processor_->GetOptions()
          ->bounds_sensitive_features()
          ->enabled()) {
    if (!ModelBoundsSensitiveScoreChunks(
            num_tokens, span_of_interest, inference_span, cached_features,
            selection_interpreter, &scored_chunks)) {
      return false;
    }
  } else {
    if (!ModelClickContextScoreChunks(num_tokens, span_of_interest,
                                      cached_features, selection_interpreter,
                                      &scored_chunks)) {
      return false;
    }
  }
  // Sorting the reversed range ascending leaves the vector in descending
  // score order, so forward iteration below visits highest scores first.
  std::sort(scored_chunks.rbegin(), scored_chunks.rend(),
            [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
              return lhs.score < rhs.score;
            });

  // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
  // them greedily as long as they do not overlap with any previously picked
  // chunks.
  // token_used is indexed relative to inference_span.first.
  std::vector<bool> token_used(TokenSpanSize(inference_span));
  chunks->clear();
  for (const ScoredChunk& scored_chunk : scored_chunks) {
    bool feasible = true;
    for (int i = scored_chunk.token_span.first;
         i < scored_chunk.token_span.second; ++i) {
      if (token_used[i - inference_span.first]) {
        feasible = false;
        break;
      }
    }

    if (!feasible) {
      continue;
    }

    for (int i = scored_chunk.token_span.first;
         i < scored_chunk.token_span.second; ++i) {
      token_used[i - inference_span.first] = true;
    }

    chunks->push_back(scored_chunk.token_span);
  }

  // Return chunks in positional order, not score order.
  std::sort(chunks->begin(), chunks->end());

  return true;
}

namespace {
// Updates the value at the given key in the map to maximum of the current value
// and the given value, or simply inserts the value if the key is not yet there.
template <typename Map>
void UpdateMax(Map* map, typename Map::key_type key,
               typename Map::mapped_type value) {
  const auto it = map->find(key);
  if (it != map->end()) {
    it->second = std::max(it->second, value);
  } else {
    (*map)[key] = value;
  }
}
}  // namespace

// Scores candidate token spans in click-context mode: for every possible
// click position in the span of interest, the model predicts a score per
// selection label; each label maps to a span relative to the click, and each
// resulting span keeps the maximum score seen over all clicks.
bool Annotator::ModelClickContextScoreChunks(
    int num_tokens, const TokenSpan& span_of_interest,
    const CachedFeatures& cached_features,
    tflite::Interpreter* selection_interpreter,
    std::vector<ScoredChunk>* scored_chunks) const {
  const int max_batch_size = model_->selection_options()->batch_size();

  std::vector<float> all_features;
  std::map<TokenSpan, float> chunk_scores;
  // Click positions are processed in batches of at most max_batch_size.
  for (int batch_start = span_of_interest.first;
       batch_start < span_of_interest.second; batch_start += max_batch_size) {
    const int batch_end =
        std::min(batch_start + max_batch_size, span_of_interest.second);

    // Prepare features for the whole batch.
    // Buffer is hoisted out of the loop and cleared to reuse its capacity.
    all_features.clear();
    all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
    for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
      cached_features.AppendClickContextFeaturesForClick(click_pos,
                                                         &all_features);
    }

    // Run batched inference.
    const int batch_size = batch_end - batch_start;
    const int features_size = cached_features.OutputFeaturesSize();
    TensorView<float> logits = selection_executor_->ComputeLogits(
        TensorView<float>(all_features.data(), {batch_size, features_size}),
        selection_interpreter);
    if (!logits.is_valid()) {
      TC3_LOG(ERROR) << "Couldn't compute logits.";
      return false;
    }
    // Expected output: [batch_size, num_selection_labels].
    if (logits.dims() != 2 || logits.dim(0) != batch_size ||
        logits.dim(1) !=
            selection_feature_processor_->GetSelectionLabelCount()) {
      TC3_LOG(ERROR) << "Mismatching output.";
      return false;
    }

    // Save results.
    for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
      // Softmax over this click's row of logits.
      const std::vector<float> scores = ComputeSoftmax(
          logits.data() + logits.dim(1) * (click_pos - batch_start),
          logits.dim(1));
      for (int j = 0;
           j < selection_feature_processor_->GetSelectionLabelCount(); ++j) {
        TokenSpan relative_token_span;
        if (!selection_feature_processor_->LabelToTokenSpan(
                j, &relative_token_span)) {
          TC3_LOG(ERROR) << "Couldn't map the label to a token span.";
          return false;
        }
        const TokenSpan candidate_span = ExpandTokenSpan(
            SingleTokenSpan(click_pos), relative_token_span.first,
            relative_token_span.second);
        // Drop spans that fall outside the token range; keep the best score
        // per span across all clicks.
        if (candidate_span.first >= 0 && candidate_span.second <= num_tokens) {
          UpdateMax(&chunk_scores, candidate_span, scores[j]);
        }
      }
    }
  }

  // Flatten the per-span maxima into the output vector.
  scored_chunks->clear();
  scored_chunks->reserve(chunk_scores.size());
  for (const auto& entry : chunk_scores) {
    scored_chunks->push_back(ScoredChunk{entry.first,
                                         entry.second});
  }

  return true;
}

// Scores candidate token spans in bounds-sensitive mode: every span that is
// contained in the inference span, intersects the span of interest, and is at
// most max_chunk_length tokens long gets scored directly by the model (one
// logit per span). Single-token spans may be assigned a fixed zero score
// instead of being run through the model, if the options say so.
bool Annotator::ModelBoundsSensitiveScoreChunks(
    int num_tokens, const TokenSpan& span_of_interest,
    const TokenSpan& inference_span, const CachedFeatures& cached_features,
    tflite::Interpreter* selection_interpreter,
    std::vector<ScoredChunk>* scored_chunks) const {
  const int max_selection_span =
      selection_feature_processor_->GetOptions()->max_selection_span();
  // The reduced output space halves the maximum chunk length.
  const int max_chunk_length = selection_feature_processor_->GetOptions()
                                       ->selection_reduced_output_space()
                                   ? max_selection_span + 1
                                   : 2 * max_selection_span + 1;
  const bool score_single_token_spans_as_zero =
      selection_feature_processor_->GetOptions()
          ->bounds_sensitive_features()
          ->score_single_token_spans_as_zero();

  scored_chunks->clear();
  if (score_single_token_spans_as_zero) {
    scored_chunks->reserve(TokenSpanSize(span_of_interest));
  }

  // Prepare all chunk candidates into one batch:
  //   - Are contained in the inference span
  //   - Have a non-empty intersection with the span of interest
  //   - Are at least one token long
  //   - Are not longer than the maximum chunk length
  std::vector<TokenSpan> candidate_spans;
  for (int start = inference_span.first; start < span_of_interest.second;
       ++start) {
    const int leftmost_end_index = std::max(start, span_of_interest.first) + 1;
    for (int end = leftmost_end_index;
         end <= inference_span.second && end - start <= max_chunk_length;
         ++end) {
      const TokenSpan candidate_span = {start, end};
      if (score_single_token_spans_as_zero &&
          TokenSpanSize(candidate_span) == 1) {
        // Do not include the single token span in the batch, add a zero score
        // for it directly to the output.
        scored_chunks->push_back(ScoredChunk{candidate_span, 0.0f});
      } else {
        candidate_spans.push_back(candidate_span);
      }
    }
  }

  const int max_batch_size = model_->selection_options()->batch_size();

  std::vector<float> all_features;
  scored_chunks->reserve(scored_chunks->size() + candidate_spans.size());
  // Score the candidate spans in batches of at most max_batch_size.
  for (int batch_start = 0; batch_start < candidate_spans.size();
       batch_start += max_batch_size) {
    const int batch_end = std::min(batch_start + max_batch_size,
                                   static_cast<int>(candidate_spans.size()));

    // Prepare features for the whole batch.
    all_features.clear();
    all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
    for (int i = batch_start; i < batch_end; ++i) {
      cached_features.AppendBoundsSensitiveFeaturesForSpan(candidate_spans[i],
                                                           &all_features);
    }

    // Run batched inference.
    const int batch_size = batch_end - batch_start;
    const int features_size = cached_features.OutputFeaturesSize();
    TensorView<float> logits = selection_executor_->ComputeLogits(
        TensorView<float>(all_features.data(), {batch_size, features_size}),
        selection_interpreter);
    if (!logits.is_valid()) {
      TC3_LOG(ERROR) << "Couldn't compute logits.";
      return false;
    }
    // Expected output: one score per candidate span, shape [batch_size, 1].
    if (logits.dims() != 2 || logits.dim(0) != batch_size ||
        logits.dim(1) != 1) {
      TC3_LOG(ERROR) << "Mismatching output.";
      return false;
    }

    // Save results.
    for (int i = batch_start; i < batch_end; ++i) {
      scored_chunks->push_back(
          ScoredChunk{candidate_spans[i], logits.data()[i - batch_start]});
    }
  }

  return true;
}

// Parses date/time references from the context and appends one AnnotatedSpan
// per parsed span to `result`, with one classification per alternative
// interpretation. A missing datetime parser is not an error (returns true
// with no output).
bool Annotator::DatetimeChunk(const UnicodeText& context_unicode,
                              int64 reference_time_ms_utc,
                              const std::string& reference_timezone,
                              const std::string& locales, ModeFlag mode,
                              AnnotationUsecase annotation_usecase,
                              bool is_serialized_entity_data_enabled,
                              std::vector<AnnotatedSpan>* result) const {
  if (!datetime_parser_) {
    return true;
  }

  std::vector<DatetimeParseResultSpan> datetime_spans;
  if (!datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
                               reference_timezone, locales, mode,
                               annotation_usecase,
                               /*anchor_start_end=*/false, &datetime_spans)) {
    return false;
  }
  for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
    AnnotatedSpan annotated_span;
    annotated_span.span = datetime_span.span;
    // Each alternative parse becomes its own classification result.
    for (const DatetimeParseResult& parse_result : datetime_span.data) {
      annotated_span.classification.emplace_back(
          PickCollectionForDatetime(parse_result),
          datetime_span.target_classification_score,
          datetime_span.priority_score);
      annotated_span.classification.back().datetime_parse_result = parse_result;
      if (is_serialized_entity_data_enabled) {
        annotated_span.classification.back().serialized_entity_data =
            CreateDatetimeSerializedEntityData(parse_result);
      }
    }
    annotated_span.source = AnnotatedSpan::Source::DATETIME;
    result->push_back(std::move(annotated_span));
  }
  return true;
}

// Accessor for the underlying flatbuffer model.
const Model* Annotator::model() const { return model_; }
// Accessor for the entity data reflection schema.
const reflection::Schema* Annotator::entity_data_schema() const {
  return entity_data_schema_;
}

// Verifies and returns a read-only view of a serialized Model buffer, or
// nullptr if the buffer is null or fails verification.
const Model* ViewModel(const void* buffer, int size) {
  if (!buffer) {
    return nullptr;
  }

  return LoadAndVerifyModel(buffer, size);
}

// Looks up an entity by id in the knowledge engine, if one is configured.
bool Annotator::LookUpKnowledgeEntity(
    const std::string& id, std::string* serialized_knowledge_result) const {
  return knowledge_engine_ &&
         knowledge_engine_->LookUpEntity(id, serialized_knowledge_result);
}

}  // namespace libtextclassifier3