Home | History | Annotate | Download | only in annotator
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "annotator/annotator.h"
     18 
     19 #include <algorithm>
     20 #include <cctype>
     21 #include <cmath>
     22 #include <iterator>
     23 #include <numeric>
     24 #include <unordered_map>
     25 
     26 #include "annotator/collections.h"
     27 #include "annotator/model_generated.h"
     28 #include "annotator/types.h"
     29 #include "utils/base/logging.h"
     30 #include "utils/checksum.h"
     31 #include "utils/math/softmax.h"
     32 #include "utils/regex-match.h"
     33 #include "utils/utf8/unicodetext.h"
     34 #include "utils/zlib/zlib_regex.h"
     35 
     36 
     37 namespace libtextclassifier3 {
     38 
     39 using SortedIntSet = std::set<int, std::function<bool(int, int)>>;
     40 
     41 const std::string& Annotator::kPhoneCollection =
     42     *[]() { return new std::string("phone"); }();
     43 const std::string& Annotator::kAddressCollection =
     44     *[]() { return new std::string("address"); }();
     45 const std::string& Annotator::kDateCollection =
     46     *[]() { return new std::string("date"); }();
     47 const std::string& Annotator::kUrlCollection =
     48     *[]() { return new std::string("url"); }();
     49 const std::string& Annotator::kEmailCollection =
     50     *[]() { return new std::string("email"); }();
     51 
     52 namespace {
     53 const Model* LoadAndVerifyModel(const void* addr, int size) {
     54   flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
     55   if (VerifyModelBuffer(verifier)) {
     56     return GetModel(addr);
     57   } else {
     58     return nullptr;
     59   }
     60 }
     61 
     62 // If lib is not nullptr, just returns lib. Otherwise, if lib is nullptr, will
     63 // create a new instance, assign ownership to owned_lib, and return it.
     64 const UniLib* MaybeCreateUnilib(const UniLib* lib,
     65                                 std::unique_ptr<UniLib>* owned_lib) {
     66   if (lib) {
     67     return lib;
     68   } else {
     69     owned_lib->reset(new UniLib);
     70     return owned_lib->get();
     71   }
     72 }
     73 
     74 // As above, but for CalendarLib.
     75 const CalendarLib* MaybeCreateCalendarlib(
     76     const CalendarLib* lib, std::unique_ptr<CalendarLib>* owned_lib) {
     77   if (lib) {
     78     return lib;
     79   } else {
     80     owned_lib->reset(new CalendarLib);
     81     return owned_lib->get();
     82   }
     83 }
     84 
     85 }  // namespace
     86 
     87 tflite::Interpreter* InterpreterManager::SelectionInterpreter() {
     88   if (!selection_interpreter_) {
     89     TC3_CHECK(selection_executor_);
     90     selection_interpreter_ = selection_executor_->CreateInterpreter();
     91     if (!selection_interpreter_) {
     92       TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
     93     }
     94   }
     95   return selection_interpreter_.get();
     96 }
     97 
     98 tflite::Interpreter* InterpreterManager::ClassificationInterpreter() {
     99   if (!classification_interpreter_) {
    100     TC3_CHECK(classification_executor_);
    101     classification_interpreter_ = classification_executor_->CreateInterpreter();
    102     if (!classification_interpreter_) {
    103       TC3_LOG(ERROR) << "Could not build TFLite interpreter.";
    104     }
    105   }
    106   return classification_interpreter_.get();
    107 }
    108 
    109 std::unique_ptr<Annotator> Annotator::FromUnownedBuffer(
    110     const char* buffer, int size, const UniLib* unilib,
    111     const CalendarLib* calendarlib) {
    112   const Model* model = LoadAndVerifyModel(buffer, size);
    113   if (model == nullptr) {
    114     return nullptr;
    115   }
    116 
    117   auto classifier =
    118       std::unique_ptr<Annotator>(new Annotator(model, unilib, calendarlib));
    119   if (!classifier->IsInitialized()) {
    120     return nullptr;
    121   }
    122 
    123   return classifier;
    124 }
    125 
    126 
    127 std::unique_ptr<Annotator> Annotator::FromScopedMmap(
    128     std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib,
    129     const CalendarLib* calendarlib) {
    130   if (!(*mmap)->handle().ok()) {
    131     TC3_VLOG(1) << "Mmap failed.";
    132     return nullptr;
    133   }
    134 
    135   const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
    136                                           (*mmap)->handle().num_bytes());
    137   if (!model) {
    138     TC3_LOG(ERROR) << "Model verification failed.";
    139     return nullptr;
    140   }
    141 
    142   auto classifier = std::unique_ptr<Annotator>(
    143       new Annotator(mmap, model, unilib, calendarlib));
    144   if (!classifier->IsInitialized()) {
    145     return nullptr;
    146   }
    147 
    148   return classifier;
    149 }
    150 
    151 std::unique_ptr<Annotator> Annotator::FromScopedMmap(
    152     std::unique_ptr<ScopedMmap>* mmap, std::unique_ptr<UniLib> unilib,
    153     std::unique_ptr<CalendarLib> calendarlib) {
    154   if (!(*mmap)->handle().ok()) {
    155     TC3_VLOG(1) << "Mmap failed.";
    156     return nullptr;
    157   }
    158 
    159   const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
    160                                           (*mmap)->handle().num_bytes());
    161   if (model == nullptr) {
    162     TC3_LOG(ERROR) << "Model verification failed.";
    163     return nullptr;
    164   }
    165 
    166   auto classifier = std::unique_ptr<Annotator>(
    167       new Annotator(mmap, model, std::move(unilib), std::move(calendarlib)));
    168   if (!classifier->IsInitialized()) {
    169     return nullptr;
    170   }
    171 
    172   return classifier;
    173 }
    174 
    175 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
    176     int fd, int offset, int size, const UniLib* unilib,
    177     const CalendarLib* calendarlib) {
    178   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
    179   return FromScopedMmap(&mmap, unilib, calendarlib);
    180 }
    181 
    182 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
    183     int fd, int offset, int size, std::unique_ptr<UniLib> unilib,
    184     std::unique_ptr<CalendarLib> calendarlib) {
    185   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
    186   return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
    187 }
    188 
    189 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
    190     int fd, const UniLib* unilib, const CalendarLib* calendarlib) {
    191   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
    192   return FromScopedMmap(&mmap, unilib, calendarlib);
    193 }
    194 
    195 std::unique_ptr<Annotator> Annotator::FromFileDescriptor(
    196     int fd, std::unique_ptr<UniLib> unilib,
    197     std::unique_ptr<CalendarLib> calendarlib) {
    198   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
    199   return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
    200 }
    201 
    202 std::unique_ptr<Annotator> Annotator::FromPath(const std::string& path,
    203                                                const UniLib* unilib,
    204                                                const CalendarLib* calendarlib) {
    205   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
    206   return FromScopedMmap(&mmap, unilib, calendarlib);
    207 }
    208 
    209 std::unique_ptr<Annotator> Annotator::FromPath(
    210     const std::string& path, std::unique_ptr<UniLib> unilib,
    211     std::unique_ptr<CalendarLib> calendarlib) {
    212   std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
    213   return FromScopedMmap(&mmap, std::move(unilib), std::move(calendarlib));
    214 }
    215 
// Builds an Annotator over a mapped model file. Ownership of the mapping is
// moved out of '*mmap'; 'model' is the Model that FromScopedMmap verified from
// that mapping. 'unilib'/'calendarlib' are borrowed when non-null; otherwise a
// default instance is created and owned by this object.
// NOTE(review): unilib_/calendarlib_ are initialized by MaybeCreate*, which
// writes into owned_unilib_/owned_calendarlib_. This is only correct if the
// owned_* members are declared before the corresponding raw pointers in
// annotator.h (members initialize in declaration order) — confirm there.
Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
                     const UniLib* unilib, const CalendarLib* calendarlib)
    : model_(model),
      mmap_(std::move(*mmap)),
      owned_unilib_(nullptr),
      unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)),
      owned_calendarlib_(nullptr),
      calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) {
  ValidateAndInitialize();
}

// As above, but takes ownership of 'unilib' and 'calendarlib'; the raw
// pointers unilib_/calendarlib_ alias the owned instances.
Annotator::Annotator(std::unique_ptr<ScopedMmap>* mmap, const Model* model,
                     std::unique_ptr<UniLib> unilib,
                     std::unique_ptr<CalendarLib> calendarlib)
    : model_(model),
      mmap_(std::move(*mmap)),
      owned_unilib_(std::move(unilib)),
      unilib_(owned_unilib_.get()),
      owned_calendarlib_(std::move(calendarlib)),
      calendarlib_(owned_calendarlib_.get()) {
  ValidateAndInitialize();
}

// Builds an Annotator over a model that lives in a caller-owned buffer (used
// by FromUnownedBuffer); no mapping is held. 'unilib'/'calendarlib' are
// borrowed when non-null, otherwise default instances are created and owned.
Annotator::Annotator(const Model* model, const UniLib* unilib,
                     const CalendarLib* calendarlib)
    : model_(model),
      owned_unilib_(nullptr),
      unilib_(MaybeCreateUnilib(unilib, &owned_unilib_)),
      owned_calendarlib_(nullptr),
      calendarlib_(MaybeCreateCalendarlib(calendarlib, &owned_calendarlib_)) {
  ValidateAndInitialize();
}
    248 
    249 void Annotator::ValidateAndInitialize() {
    250   initialized_ = false;
    251 
    252   if (model_ == nullptr) {
    253     TC3_LOG(ERROR) << "No model specified.";
    254     return;
    255   }
    256 
    257   const bool model_enabled_for_annotation =
    258       (model_->triggering_options() != nullptr &&
    259        (model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION));
    260   const bool model_enabled_for_classification =
    261       (model_->triggering_options() != nullptr &&
    262        (model_->triggering_options()->enabled_modes() &
    263         ModeFlag_CLASSIFICATION));
    264   const bool model_enabled_for_selection =
    265       (model_->triggering_options() != nullptr &&
    266        (model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION));
    267 
    268   // Annotation requires the selection model.
    269   if (model_enabled_for_annotation || model_enabled_for_selection) {
    270     if (!model_->selection_options()) {
    271       TC3_LOG(ERROR) << "No selection options.";
    272       return;
    273     }
    274     if (!model_->selection_feature_options()) {
    275       TC3_LOG(ERROR) << "No selection feature options.";
    276       return;
    277     }
    278     if (!model_->selection_feature_options()->bounds_sensitive_features()) {
    279       TC3_LOG(ERROR) << "No selection bounds sensitive feature options.";
    280       return;
    281     }
    282     if (!model_->selection_model()) {
    283       TC3_LOG(ERROR) << "No selection model.";
    284       return;
    285     }
    286     selection_executor_ = ModelExecutor::FromBuffer(model_->selection_model());
    287     if (!selection_executor_) {
    288       TC3_LOG(ERROR) << "Could not initialize selection executor.";
    289       return;
    290     }
    291     selection_feature_processor_.reset(
    292         new FeatureProcessor(model_->selection_feature_options(), unilib_));
    293   }
    294 
    295   // Annotation requires the classification model for conflict resolution and
    296   // scoring.
    297   // Selection requires the classification model for conflict resolution.
    298   if (model_enabled_for_annotation || model_enabled_for_classification ||
    299       model_enabled_for_selection) {
    300     if (!model_->classification_options()) {
    301       TC3_LOG(ERROR) << "No classification options.";
    302       return;
    303     }
    304 
    305     if (!model_->classification_feature_options()) {
    306       TC3_LOG(ERROR) << "No classification feature options.";
    307       return;
    308     }
    309 
    310     if (!model_->classification_feature_options()
    311              ->bounds_sensitive_features()) {
    312       TC3_LOG(ERROR) << "No classification bounds sensitive feature options.";
    313       return;
    314     }
    315     if (!model_->classification_model()) {
    316       TC3_LOG(ERROR) << "No clf model.";
    317       return;
    318     }
    319 
    320     classification_executor_ =
    321         ModelExecutor::FromBuffer(model_->classification_model());
    322     if (!classification_executor_) {
    323       TC3_LOG(ERROR) << "Could not initialize classification executor.";
    324       return;
    325     }
    326 
    327     classification_feature_processor_.reset(new FeatureProcessor(
    328         model_->classification_feature_options(), unilib_));
    329   }
    330 
    331   // The embeddings need to be specified if the model is to be used for
    332   // classification or selection.
    333   if (model_enabled_for_annotation || model_enabled_for_classification ||
    334       model_enabled_for_selection) {
    335     if (!model_->embedding_model()) {
    336       TC3_LOG(ERROR) << "No embedding model.";
    337       return;
    338     }
    339 
    340     // Check that the embedding size of the selection and classification model
    341     // matches, as they are using the same embeddings.
    342     if (model_enabled_for_selection &&
    343         (model_->selection_feature_options()->embedding_size() !=
    344              model_->classification_feature_options()->embedding_size() ||
    345          model_->selection_feature_options()->embedding_quantization_bits() !=
    346              model_->classification_feature_options()
    347                  ->embedding_quantization_bits())) {
    348       TC3_LOG(ERROR) << "Mismatching embedding size/quantization.";
    349       return;
    350     }
    351 
    352     embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
    353         model_->embedding_model(),
    354         model_->classification_feature_options()->embedding_size(),
    355         model_->classification_feature_options()->embedding_quantization_bits(),
    356         model_->embedding_pruning_mask());
    357     if (!embedding_executor_) {
    358       TC3_LOG(ERROR) << "Could not initialize embedding executor.";
    359       return;
    360     }
    361   }
    362 
    363   std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
    364   if (model_->regex_model()) {
    365     if (!InitializeRegexModel(decompressor.get())) {
    366       TC3_LOG(ERROR) << "Could not initialize regex model.";
    367       return;
    368     }
    369   }
    370 
    371   if (model_->datetime_model()) {
    372     datetime_parser_ = DatetimeParser::Instance(
    373         model_->datetime_model(), *unilib_, *calendarlib_, decompressor.get());
    374     if (!datetime_parser_) {
    375       TC3_LOG(ERROR) << "Could not initialize datetime parser.";
    376       return;
    377     }
    378   }
    379 
    380   if (model_->output_options()) {
    381     if (model_->output_options()->filtered_collections_annotation()) {
    382       for (const auto collection :
    383            *model_->output_options()->filtered_collections_annotation()) {
    384         filtered_collections_annotation_.insert(collection->str());
    385       }
    386     }
    387     if (model_->output_options()->filtered_collections_classification()) {
    388       for (const auto collection :
    389            *model_->output_options()->filtered_collections_classification()) {
    390         filtered_collections_classification_.insert(collection->str());
    391       }
    392     }
    393     if (model_->output_options()->filtered_collections_selection()) {
    394       for (const auto collection :
    395            *model_->output_options()->filtered_collections_selection()) {
    396         filtered_collections_selection_.insert(collection->str());
    397       }
    398     }
    399   }
    400 
    401   if (model_->number_annotator_options() &&
    402       model_->number_annotator_options()->enabled()) {
    403     if (selection_feature_processor_ == nullptr) {
    404       TC3_LOG(ERROR)
    405           << "Could not initialize NumberAnnotator without a feature processor";
    406       return;
    407     }
    408 
    409     number_annotator_.reset(
    410         new NumberAnnotator(model_->number_annotator_options(),
    411                             selection_feature_processor_.get()));
    412   }
    413 
    414   if (model_->duration_annotator_options() &&
    415       model_->duration_annotator_options()->enabled()) {
    416     duration_annotator_.reset(
    417         new DurationAnnotator(model_->duration_annotator_options(),
    418                               selection_feature_processor_.get()));
    419   }
    420 
    421   if (model_->entity_data_schema()) {
    422     entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
    423         model_->entity_data_schema()->Data(),
    424         model_->entity_data_schema()->size());
    425     if (entity_data_schema_ == nullptr) {
    426       TC3_LOG(ERROR) << "Could not load entity data schema data.";
    427       return;
    428     }
    429 
    430     entity_data_builder_.reset(
    431         new ReflectiveFlatbufferBuilder(entity_data_schema_));
    432   } else {
    433     entity_data_schema_ = nullptr;
    434     entity_data_builder_ = nullptr;
    435   }
    436 
    437   if (model_->triggering_locales() &&
    438       !ParseLocales(model_->triggering_locales()->c_str(),
    439                     &model_triggering_locales_)) {
    440     TC3_LOG(ERROR) << "Could not parse model supported locales.";
    441     return;
    442   }
    443 
    444   if (model_->triggering_options() != nullptr &&
    445       model_->triggering_options()->locales() != nullptr &&
    446       !ParseLocales(model_->triggering_options()->locales()->c_str(),
    447                     &ml_model_triggering_locales_)) {
    448     TC3_LOG(ERROR) << "Could not parse supported ML model locales.";
    449     return;
    450   }
    451 
    452   if (model_->triggering_options() != nullptr &&
    453       model_->triggering_options()->dictionary_locales() != nullptr &&
    454       !ParseLocales(model_->triggering_options()->dictionary_locales()->c_str(),
    455                     &dictionary_locales_)) {
    456     TC3_LOG(ERROR) << "Could not parse dictionary supported locales.";
    457     return;
    458   }
    459 
    460   initialized_ = true;
    461 }
    462 
    463 bool Annotator::InitializeRegexModel(ZlibDecompressor* decompressor) {
    464   if (!model_->regex_model()->patterns()) {
    465     return true;
    466   }
    467 
    468   // Initialize pattern recognizers.
    469   int regex_pattern_id = 0;
    470   for (const auto& regex_pattern : *model_->regex_model()->patterns()) {
    471     std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
    472         UncompressMakeRegexPattern(
    473             *unilib_, regex_pattern->pattern(),
    474             regex_pattern->compressed_pattern(),
    475             model_->regex_model()->lazy_regex_compilation(), decompressor);
    476     if (!compiled_pattern) {
    477       TC3_LOG(INFO) << "Failed to load regex pattern";
    478       return false;
    479     }
    480 
    481     if (regex_pattern->enabled_modes() & ModeFlag_ANNOTATION) {
    482       annotation_regex_patterns_.push_back(regex_pattern_id);
    483     }
    484     if (regex_pattern->enabled_modes() & ModeFlag_CLASSIFICATION) {
    485       classification_regex_patterns_.push_back(regex_pattern_id);
    486     }
    487     if (regex_pattern->enabled_modes() & ModeFlag_SELECTION) {
    488       selection_regex_patterns_.push_back(regex_pattern_id);
    489     }
    490     regex_patterns_.push_back({
    491         regex_pattern,
    492         std::move(compiled_pattern),
    493     });
    494     ++regex_pattern_id;
    495   }
    496 
    497   return true;
    498 }
    499 
    500 bool Annotator::InitializeKnowledgeEngine(
    501     const std::string& serialized_config) {
    502   std::unique_ptr<KnowledgeEngine> knowledge_engine(
    503       new KnowledgeEngine(unilib_));
    504   if (!knowledge_engine->Initialize(serialized_config)) {
    505     TC3_LOG(ERROR) << "Failed to initialize the knowledge engine.";
    506     return false;
    507   }
    508   knowledge_engine_ = std::move(knowledge_engine);
    509   return true;
    510 }
    511 
    512 bool Annotator::InitializeContactEngine(const std::string& serialized_config) {
    513   std::unique_ptr<ContactEngine> contact_engine(
    514       new ContactEngine(selection_feature_processor_.get(), unilib_));
    515   if (!contact_engine->Initialize(serialized_config)) {
    516     TC3_LOG(ERROR) << "Failed to initialize the contact engine.";
    517     return false;
    518   }
    519   contact_engine_ = std::move(contact_engine);
    520   return true;
    521 }
    522 
    523 bool Annotator::InitializeInstalledAppEngine(
    524     const std::string& serialized_config) {
    525   std::unique_ptr<InstalledAppEngine> installed_app_engine(
    526       new InstalledAppEngine(selection_feature_processor_.get(), unilib_));
    527   if (!installed_app_engine->Initialize(serialized_config)) {
    528     TC3_LOG(ERROR) << "Failed to initialize the installed app engine.";
    529     return false;
    530   }
    531   installed_app_engine_ = std::move(installed_app_engine);
    532   return true;
    533 }
    534 
    535 namespace {
    536 
    537 int CountDigits(const std::string& str, CodepointSpan selection_indices) {
    538   int count = 0;
    539   int i = 0;
    540   const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false);
    541   for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) {
    542     if (i >= selection_indices.first && i < selection_indices.second &&
    543         isdigit(*it)) {
    544       ++count;
    545     }
    546   }
    547   return count;
    548 }
    549 
    550 }  // namespace
    551 
    552 namespace internal {
    553 // Helper function, which if the initial 'span' contains only white-spaces,
    554 // moves the selection to a single-codepoint selection on a left or right side
    555 // of this space.
    556 CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span,
    557                                             const UnicodeText& context_unicode,
    558                                             const UniLib& unilib) {
    559   TC3_CHECK(ValidNonEmptySpan(span));
    560 
    561   UnicodeText::const_iterator it;
    562 
    563   // Check that the current selection is all whitespaces.
    564   it = context_unicode.begin();
    565   std::advance(it, span.first);
    566   for (int i = 0; i < (span.second - span.first); ++i, ++it) {
    567     if (!unilib.IsWhitespace(*it)) {
    568       return span;
    569     }
    570   }
    571 
    572   CodepointSpan result;
    573 
    574   // Try moving left.
    575   result = span;
    576   it = context_unicode.begin();
    577   std::advance(it, span.first);
    578   while (it != context_unicode.begin() && unilib.IsWhitespace(*it)) {
    579     --result.first;
    580     --it;
    581   }
    582   result.second = result.first + 1;
    583   if (!unilib.IsWhitespace(*it)) {
    584     return result;
    585   }
    586 
    587   // If moving left didn't find a non-whitespace character, just return the
    588   // original span.
    589   return span;
    590 }
    591 }  // namespace internal
    592 
    593 bool Annotator::FilteredForAnnotation(const AnnotatedSpan& span) const {
    594   return !span.classification.empty() &&
    595          filtered_collections_annotation_.find(
    596              span.classification[0].collection) !=
    597              filtered_collections_annotation_.end();
    598 }
    599 
    600 bool Annotator::FilteredForClassification(
    601     const ClassificationResult& classification) const {
    602   return filtered_collections_classification_.find(classification.collection) !=
    603          filtered_collections_classification_.end();
    604 }
    605 
    606 bool Annotator::FilteredForSelection(const AnnotatedSpan& span) const {
    607   return !span.classification.empty() &&
    608          filtered_collections_selection_.find(
    609              span.classification[0].collection) !=
    610              filtered_collections_selection_.end();
    611 }
    612 
    613 namespace {
    614 inline bool ClassifiedAsOther(
    615     const std::vector<ClassificationResult>& classification) {
    616   return !classification.empty() &&
    617          classification[0].collection == Collections::Other();
    618 }
    619 
    620 float GetPriorityScore(
    621     const std::vector<ClassificationResult>& classification) {
    622   if (!classification.empty() && !ClassifiedAsOther(classification)) {
    623     return classification[0].priority_score;
    624   } else {
    625     return -1.0;
    626   }
    627 }
    628 }  // namespace
    629 
    630 bool Annotator::VerifyRegexMatchCandidate(
    631     const std::string& context, const VerificationOptions* verification_options,
    632     const std::string& match, const UniLib::RegexMatcher* matcher) const {
    633   if (verification_options == nullptr) {
    634     return true;
    635   }
    636   if (verification_options->verify_luhn_checksum() &&
    637       !VerifyLuhnChecksum(match)) {
    638     return false;
    639   }
    640   const int lua_verifier = verification_options->lua_verifier();
    641   if (lua_verifier >= 0) {
    642     if (model_->regex_model()->lua_verifier() == nullptr ||
    643         lua_verifier >= model_->regex_model()->lua_verifier()->size()) {
    644       TC3_LOG(ERROR) << "Invalid lua verifier specified: " << lua_verifier;
    645       return false;
    646     }
    647     return VerifyMatch(
    648         context, matcher,
    649         model_->regex_model()->lua_verifier()->Get(lua_verifier)->str());
    650   }
    651   return true;
    652 }
    653 
    654 CodepointSpan Annotator::SuggestSelection(
    655     const std::string& context, CodepointSpan click_indices,
    656     const SelectionOptions& options) const {
    657   CodepointSpan original_click_indices = click_indices;
    658   if (!initialized_) {
    659     TC3_LOG(ERROR) << "Not initialized";
    660     return original_click_indices;
    661   }
    662   if (!(model_->enabled_modes() & ModeFlag_SELECTION)) {
    663     return original_click_indices;
    664   }
    665 
    666   std::vector<Locale> detected_text_language_tags;
    667   if (!ParseLocales(options.detected_text_language_tags,
    668                     &detected_text_language_tags)) {
    669     TC3_LOG(WARNING)
    670         << "Failed to parse the detected_text_language_tags in options: "
    671         << options.detected_text_language_tags;
    672   }
    673   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
    674                                     model_triggering_locales_,
    675                                     /*default_value=*/true)) {
    676     return original_click_indices;
    677   }
    678 
    679   const UnicodeText context_unicode = UTF8ToUnicodeText(context,
    680                                                         /*do_copy=*/false);
    681 
    682   if (!context_unicode.is_valid()) {
    683     return original_click_indices;
    684   }
    685 
    686   const int context_codepoint_size = context_unicode.size_codepoints();
    687 
    688   if (click_indices.first < 0 || click_indices.second < 0 ||
    689       click_indices.first >= context_codepoint_size ||
    690       click_indices.second > context_codepoint_size ||
    691       click_indices.first >= click_indices.second) {
    692     TC3_VLOG(1) << "Trying to run SuggestSelection with invalid indices: "
    693                 << click_indices.first << " " << click_indices.second;
    694     return original_click_indices;
    695   }
    696 
    697   if (model_->snap_whitespace_selections()) {
    698     // We want to expand a purely white-space selection to a multi-selection it
    699     // would've been part of. But with this feature disabled we would do a no-
    700     // op, because no token is found. Therefore, we need to modify the
    701     // 'click_indices' a bit to include a part of the token, so that the click-
    702     // finding logic finds the clicked token correctly. This modification is
    703     // done by the following function. Note, that it's enough to check the left
    704     // side of the current selection, because if the white-space is a part of a
    705     // multi-selection, necessarily both tokens - on the left and the right
    706     // sides need to be selected. Thus snapping only to the left is sufficient
    707     // (there's a check at the bottom that makes sure that if we snap to the
    708     // left token but the result does not contain the initial white-space,
    709     // returns the original indices).
    710     click_indices = internal::SnapLeftIfWhitespaceSelection(
    711         click_indices, context_unicode, *unilib_);
    712   }
    713 
    714   std::vector<AnnotatedSpan> candidates;
    715   InterpreterManager interpreter_manager(selection_executor_.get(),
    716                                          classification_executor_.get());
    717   std::vector<Token> tokens;
    718   if (!ModelSuggestSelection(context_unicode, click_indices,
    719                              detected_text_language_tags, &interpreter_manager,
    720                              &tokens, &candidates)) {
    721     TC3_LOG(ERROR) << "Model suggest selection failed.";
    722     return original_click_indices;
    723   }
    724   if (!RegexChunk(context_unicode, selection_regex_patterns_, &candidates,
    725                   /*is_serialized_entity_data_enabled=*/false)) {
    726     TC3_LOG(ERROR) << "Regex suggest selection failed.";
    727     return original_click_indices;
    728   }
    729   if (!DatetimeChunk(
    730           UTF8ToUnicodeText(context, /*do_copy=*/false),
    731           /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
    732           options.locales, ModeFlag_SELECTION, options.annotation_usecase,
    733           /*is_serialized_entity_data_enabled=*/false, &candidates)) {
    734     TC3_LOG(ERROR) << "Datetime suggest selection failed.";
    735     return original_click_indices;
    736   }
    737   if (knowledge_engine_ != nullptr &&
    738       !knowledge_engine_->Chunk(context, &candidates)) {
    739     TC3_LOG(ERROR) << "Knowledge suggest selection failed.";
    740     return original_click_indices;
    741   }
    742   if (contact_engine_ != nullptr &&
    743       !contact_engine_->Chunk(context_unicode, tokens, &candidates)) {
    744     TC3_LOG(ERROR) << "Contact suggest selection failed.";
    745     return original_click_indices;
    746   }
    747   if (installed_app_engine_ != nullptr &&
    748       !installed_app_engine_->Chunk(context_unicode, tokens, &candidates)) {
    749     TC3_LOG(ERROR) << "Installed app suggest selection failed.";
    750     return original_click_indices;
    751   }
    752   if (number_annotator_ != nullptr &&
    753       !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
    754                                   &candidates)) {
    755     TC3_LOG(ERROR) << "Number annotator failed in suggest selection.";
    756     return original_click_indices;
    757   }
    758   if (duration_annotator_ != nullptr &&
    759       !duration_annotator_->FindAll(context_unicode, tokens,
    760                                     options.annotation_usecase, &candidates)) {
    761     TC3_LOG(ERROR) << "Duration annotator failed in suggest selection.";
    762     return original_click_indices;
    763   }
    764 
    765   // Sort candidates according to their position in the input, so that the next
    766   // code can assume that any connected component of overlapping spans forms a
    767   // contiguous block.
    768   std::sort(candidates.begin(), candidates.end(),
    769             [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
    770               return a.span.first < b.span.first;
    771             });
    772 
    773   std::vector<int> candidate_indices;
    774   if (!ResolveConflicts(candidates, context, tokens,
    775                         detected_text_language_tags, options.annotation_usecase,
    776                         &interpreter_manager, &candidate_indices)) {
    777     TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
    778     return original_click_indices;
    779   }
    780 
    781   std::sort(candidate_indices.begin(), candidate_indices.end(),
    782             [&candidates](int a, int b) {
    783               return GetPriorityScore(candidates[a].classification) >
    784                      GetPriorityScore(candidates[b].classification);
    785             });
    786 
    787   for (const int i : candidate_indices) {
    788     if (SpansOverlap(candidates[i].span, click_indices) &&
    789         SpansOverlap(candidates[i].span, original_click_indices)) {
    790       // Run model classification if not present but requested and there's a
    791       // classification collection filter specified.
    792       if (candidates[i].classification.empty() &&
    793           model_->selection_options()->always_classify_suggested_selection() &&
    794           !filtered_collections_selection_.empty()) {
    795         if (!ModelClassifyText(context, detected_text_language_tags,
    796                                candidates[i].span, &interpreter_manager,
    797                                /*embedding_cache=*/nullptr,
    798                                &candidates[i].classification)) {
    799           return original_click_indices;
    800         }
    801       }
    802 
    803       // Ignore if span classification is filtered.
    804       if (FilteredForSelection(candidates[i])) {
    805         return original_click_indices;
    806       }
    807 
    808       return candidates[i].span;
    809     }
    810   }
    811 
    812   return original_click_indices;
    813 }
    814 
    815 namespace {
    816 // Helper function that returns the index of the first candidate that
    817 // transitively does not overlap with the candidate on 'start_index'. If the end
    818 // of 'candidates' is reached, it returns the index that points right behind the
    819 // array.
    820 int FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan>& candidates,
    821                                  int start_index) {
    822   int first_non_overlapping = start_index + 1;
    823   CodepointSpan conflicting_span = candidates[start_index].span;
    824   while (
    825       first_non_overlapping < candidates.size() &&
    826       SpansOverlap(conflicting_span, candidates[first_non_overlapping].span)) {
    827     // Grow the span to include the current one.
    828     conflicting_span.second = std::max(
    829         conflicting_span.second, candidates[first_non_overlapping].span.second);
    830 
    831     ++first_non_overlapping;
    832   }
    833   return first_non_overlapping;
    834 }
    835 }  // namespace
    836 
    837 bool Annotator::ResolveConflicts(
    838     const std::vector<AnnotatedSpan>& candidates, const std::string& context,
    839     const std::vector<Token>& cached_tokens,
    840     const std::vector<Locale>& detected_text_language_tags,
    841     AnnotationUsecase annotation_usecase,
    842     InterpreterManager* interpreter_manager, std::vector<int>* result) const {
    843   result->clear();
    844   result->reserve(candidates.size());
    845   for (int i = 0; i < candidates.size();) {
    846     int first_non_overlapping =
    847         FirstNonOverlappingSpanIndex(candidates, /*start_index=*/i);
    848 
    849     const bool conflict_found = first_non_overlapping != (i + 1);
    850     if (conflict_found) {
    851       std::vector<int> candidate_indices;
    852       if (!ResolveConflict(context, cached_tokens, candidates,
    853                            detected_text_language_tags, i,
    854                            first_non_overlapping, annotation_usecase,
    855                            interpreter_manager, &candidate_indices)) {
    856         return false;
    857       }
    858       result->insert(result->end(), candidate_indices.begin(),
    859                      candidate_indices.end());
    860     } else {
    861       result->push_back(i);
    862     }
    863 
    864     // Skip over the whole conflicting group/go to next candidate.
    865     i = first_non_overlapping;
    866   }
    867   return true;
    868 }
    869 
    870 namespace {
    871 // Returns true, if the given two sources do conflict in given annotation
    872 // usecase.
    873 //  - In SMART usecase, all sources do conflict, because there's only 1 possible
    874 //  annotation for a given span.
    875 //  - In RAW usecase, certain annotations are allowed to overlap (e.g. datetime
    876 //  and duration), while others not (e.g. duration and number).
    877 bool DoSourcesConflict(AnnotationUsecase annotation_usecase,
    878                        const AnnotatedSpan::Source source1,
    879                        const AnnotatedSpan::Source source2) {
    880   uint32 source_mask =
    881       (1 << static_cast<int>(source1)) | (1 << static_cast<int>(source2));
    882 
    883   switch (annotation_usecase) {
    884     case AnnotationUsecase_ANNOTATION_USECASE_SMART:
    885       // In the SMART mode, all annotations conflict.
    886       return true;
    887 
    888     case AnnotationUsecase_ANNOTATION_USECASE_RAW:
    889       // DURATION and DATETIME do not conflict. E.g. "let's meet in 3 hours",
    890       // can have two non-conflicting annotations: "in 3 hours" (datetime), "3
    891       // hours" (duration).
    892       if ((source_mask &
    893            (1 << static_cast<int>(AnnotatedSpan::Source::DURATION))) &&
    894           (source_mask &
    895            (1 << static_cast<int>(AnnotatedSpan::Source::DATETIME)))) {
    896         return false;
    897       }
    898 
    899       // A KNOWLEDGE entity does not conflict with anything.
    900       if ((source_mask &
    901            (1 << static_cast<int>(AnnotatedSpan::Source::KNOWLEDGE)))) {
    902         return false;
    903       }
    904 
    905       // Entities from other sources can conflict.
    906       return true;
    907   }
    908 }
    909 }  // namespace
    910 
    911 bool Annotator::ResolveConflict(
    912     const std::string& context, const std::vector<Token>& cached_tokens,
    913     const std::vector<AnnotatedSpan>& candidates,
    914     const std::vector<Locale>& detected_text_language_tags, int start_index,
    915     int end_index, AnnotationUsecase annotation_usecase,
    916     InterpreterManager* interpreter_manager,
    917     std::vector<int>* chosen_indices) const {
    918   std::vector<int> conflicting_indices;
    919   std::unordered_map<int, float> scores;
    920   for (int i = start_index; i < end_index; ++i) {
    921     conflicting_indices.push_back(i);
    922     if (!candidates[i].classification.empty()) {
    923       scores[i] = GetPriorityScore(candidates[i].classification);
    924       continue;
    925     }
    926 
    927     // OPTIMIZATION: So that we don't have to classify all the ML model
    928     // spans apriori, we wait until we get here, when they conflict with
    929     // something and we need the actual classification scores. So if the
    930     // candidate conflicts and comes from the model, we need to run a
    931     // classification to determine its priority:
    932     std::vector<ClassificationResult> classification;
    933     if (!ModelClassifyText(context, cached_tokens, detected_text_language_tags,
    934                            candidates[i].span, interpreter_manager,
    935                            /*embedding_cache=*/nullptr, &classification)) {
    936       return false;
    937     }
    938 
    939     if (!classification.empty()) {
    940       scores[i] = GetPriorityScore(classification);
    941     }
    942   }
    943 
    944   std::sort(conflicting_indices.begin(), conflicting_indices.end(),
    945             [&scores](int i, int j) { return scores[i] > scores[j]; });
    946 
    947   // Here we keep a set of indices that were chosen, per-source, to enable
    948   // effective computation.
    949   std::unordered_map<AnnotatedSpan::Source, SortedIntSet>
    950       chosen_indices_for_source_map;
    951 
    952   // Greedily place the candidates if they don't conflict with the already
    953   // placed ones.
    954   for (int i = 0; i < conflicting_indices.size(); ++i) {
    955     const int considered_candidate = conflicting_indices[i];
    956 
    957     // See if there is a conflict between the candidate and all already placed
    958     // candidates.
    959     bool conflict = false;
    960     SortedIntSet* chosen_indices_for_source_ptr = nullptr;
    961     for (auto& source_set_pair : chosen_indices_for_source_map) {
    962       if (source_set_pair.first == candidates[considered_candidate].source) {
    963         chosen_indices_for_source_ptr = &source_set_pair.second;
    964       }
    965 
    966       if (DoSourcesConflict(annotation_usecase, source_set_pair.first,
    967                             candidates[considered_candidate].source) &&
    968           DoesCandidateConflict(considered_candidate, candidates,
    969                                 source_set_pair.second)) {
    970         conflict = true;
    971         break;
    972       }
    973     }
    974 
    975     // Skip the candidate if a conflict was found.
    976     if (conflict) {
    977       continue;
    978     }
    979 
    980     // If the set of indices for the current source doesn't exist yet,
    981     // initialize it.
    982     if (chosen_indices_for_source_ptr == nullptr) {
    983       SortedIntSet new_set([&candidates](int a, int b) {
    984         return candidates[a].span.first < candidates[b].span.first;
    985       });
    986       chosen_indices_for_source_map[candidates[considered_candidate].source] =
    987           std::move(new_set);
    988       chosen_indices_for_source_ptr =
    989           &chosen_indices_for_source_map[candidates[considered_candidate]
    990                                              .source];
    991     }
    992 
    993     // Place the candidate to the output and to the per-source conflict set.
    994     chosen_indices->push_back(considered_candidate);
    995     chosen_indices_for_source_ptr->insert(considered_candidate);
    996   }
    997 
    998   std::sort(chosen_indices->begin(), chosen_indices->end());
    999 
   1000   return true;
   1001 }
   1002 
   1003 bool Annotator::ModelSuggestSelection(
   1004     const UnicodeText& context_unicode, CodepointSpan click_indices,
   1005     const std::vector<Locale>& detected_text_language_tags,
   1006     InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
   1007     std::vector<AnnotatedSpan>* result) const {
   1008   if (model_->triggering_options() == nullptr ||
   1009       !(model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)) {
   1010     return true;
   1011   }
   1012 
   1013   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
   1014                                     ml_model_triggering_locales_,
   1015                                     /*default_value=*/true)) {
   1016     return true;
   1017   }
   1018 
   1019   int click_pos;
   1020   *tokens = selection_feature_processor_->Tokenize(context_unicode);
   1021   selection_feature_processor_->RetokenizeAndFindClick(
   1022       context_unicode, click_indices,
   1023       selection_feature_processor_->GetOptions()->only_use_line_with_click(),
   1024       tokens, &click_pos);
   1025   if (click_pos == kInvalidIndex) {
   1026     TC3_VLOG(1) << "Could not calculate the click position.";
   1027     return false;
   1028   }
   1029 
   1030   const int symmetry_context_size =
   1031       model_->selection_options()->symmetry_context_size();
   1032   const FeatureProcessorOptions_::BoundsSensitiveFeatures*
   1033       bounds_sensitive_features = selection_feature_processor_->GetOptions()
   1034                                       ->bounds_sensitive_features();
   1035 
   1036   // The symmetry context span is the clicked token with symmetry_context_size
   1037   // tokens on either side.
   1038   const TokenSpan symmetry_context_span = IntersectTokenSpans(
   1039       ExpandTokenSpan(SingleTokenSpan(click_pos),
   1040                       /*num_tokens_left=*/symmetry_context_size,
   1041                       /*num_tokens_right=*/symmetry_context_size),
   1042       {0, tokens->size()});
   1043 
   1044   // Compute the extraction span based on the model type.
   1045   TokenSpan extraction_span;
   1046   if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
   1047     // The extraction span is the symmetry context span expanded to include
   1048     // max_selection_span tokens on either side, which is how far a selection
   1049     // can stretch from the click, plus a relevant number of tokens outside of
   1050     // the bounds of the selection.
   1051     const int max_selection_span =
   1052         selection_feature_processor_->GetOptions()->max_selection_span();
   1053     extraction_span =
   1054         ExpandTokenSpan(symmetry_context_span,
   1055                         /*num_tokens_left=*/max_selection_span +
   1056                             bounds_sensitive_features->num_tokens_before(),
   1057                         /*num_tokens_right=*/max_selection_span +
   1058                             bounds_sensitive_features->num_tokens_after());
   1059   } else {
   1060     // The extraction span is the symmetry context span expanded to include
   1061     // context_size tokens on either side.
   1062     const int context_size =
   1063         selection_feature_processor_->GetOptions()->context_size();
   1064     extraction_span = ExpandTokenSpan(symmetry_context_span,
   1065                                       /*num_tokens_left=*/context_size,
   1066                                       /*num_tokens_right=*/context_size);
   1067   }
   1068   extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});
   1069 
   1070   if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
   1071           *tokens, extraction_span)) {
   1072     return true;
   1073   }
   1074 
   1075   std::unique_ptr<CachedFeatures> cached_features;
   1076   if (!selection_feature_processor_->ExtractFeatures(
   1077           *tokens, extraction_span,
   1078           /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
   1079           embedding_executor_.get(),
   1080           /*embedding_cache=*/nullptr,
   1081           selection_feature_processor_->EmbeddingSize() +
   1082               selection_feature_processor_->DenseFeaturesCount(),
   1083           &cached_features)) {
   1084     TC3_LOG(ERROR) << "Could not extract features.";
   1085     return false;
   1086   }
   1087 
   1088   // Produce selection model candidates.
   1089   std::vector<TokenSpan> chunks;
   1090   if (!ModelChunk(tokens->size(), /*span_of_interest=*/symmetry_context_span,
   1091                   interpreter_manager->SelectionInterpreter(), *cached_features,
   1092                   &chunks)) {
   1093     TC3_LOG(ERROR) << "Could not chunk.";
   1094     return false;
   1095   }
   1096 
   1097   for (const TokenSpan& chunk : chunks) {
   1098     AnnotatedSpan candidate;
   1099     candidate.span = selection_feature_processor_->StripBoundaryCodepoints(
   1100         context_unicode, TokenSpanToCodepointSpan(*tokens, chunk));
   1101     if (model_->selection_options()->strip_unpaired_brackets()) {
   1102       candidate.span =
   1103           StripUnpairedBrackets(context_unicode, candidate.span, *unilib_);
   1104     }
   1105 
   1106     // Only output non-empty spans.
   1107     if (candidate.span.first != candidate.span.second) {
   1108       result->push_back(candidate);
   1109     }
   1110   }
   1111   return true;
   1112 }
   1113 
   1114 bool Annotator::ModelClassifyText(
   1115     const std::string& context,
   1116     const std::vector<Locale>& detected_text_language_tags,
   1117     CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
   1118     FeatureProcessor::EmbeddingCache* embedding_cache,
   1119     std::vector<ClassificationResult>* classification_results) const {
   1120   return ModelClassifyText(context, {}, detected_text_language_tags,
   1121                            selection_indices, interpreter_manager,
   1122                            embedding_cache, classification_results);
   1123 }
   1124 
   1125 namespace internal {
   1126 std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
   1127                                     CodepointSpan selection_indices,
   1128                                     TokenSpan tokens_around_selection_to_copy) {
   1129   const auto first_selection_token = std::upper_bound(
   1130       cached_tokens.begin(), cached_tokens.end(), selection_indices.first,
   1131       [](int selection_start, const Token& token) {
   1132         return selection_start < token.end;
   1133       });
   1134   const auto last_selection_token = std::lower_bound(
   1135       cached_tokens.begin(), cached_tokens.end(), selection_indices.second,
   1136       [](const Token& token, int selection_end) {
   1137         return token.start < selection_end;
   1138       });
   1139 
   1140   const int64 first_token = std::max(
   1141       static_cast<int64>(0),
   1142       static_cast<int64>((first_selection_token - cached_tokens.begin()) -
   1143                          tokens_around_selection_to_copy.first));
   1144   const int64 last_token = std::min(
   1145       static_cast<int64>(cached_tokens.size()),
   1146       static_cast<int64>((last_selection_token - cached_tokens.begin()) +
   1147                          tokens_around_selection_to_copy.second));
   1148 
   1149   std::vector<Token> tokens;
   1150   tokens.reserve(last_token - first_token);
   1151   for (int i = first_token; i < last_token; ++i) {
   1152     tokens.push_back(cached_tokens[i]);
   1153   }
   1154   return tokens;
   1155 }
   1156 }  // namespace internal
   1157 
   1158 TokenSpan Annotator::ClassifyTextUpperBoundNeededTokens() const {
   1159   const FeatureProcessorOptions_::BoundsSensitiveFeatures*
   1160       bounds_sensitive_features =
   1161           classification_feature_processor_->GetOptions()
   1162               ->bounds_sensitive_features();
   1163   if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
   1164     // The extraction span is the selection span expanded to include a relevant
   1165     // number of tokens outside of the bounds of the selection.
   1166     return {bounds_sensitive_features->num_tokens_before(),
   1167             bounds_sensitive_features->num_tokens_after()};
   1168   } else {
   1169     // The extraction span is the clicked token with context_size tokens on
   1170     // either side.
   1171     const int context_size =
   1172         selection_feature_processor_->GetOptions()->context_size();
   1173     return {context_size, context_size};
   1174   }
   1175 }
   1176 
   1177 namespace {
   1178 // Sorts the classification results from high score to low score.
   1179 void SortClassificationResults(
   1180     std::vector<ClassificationResult>* classification_results) {
   1181   std::sort(classification_results->begin(), classification_results->end(),
   1182             [](const ClassificationResult& a, const ClassificationResult& b) {
   1183               return a.score > b.score;
   1184             });
   1185 }
   1186 }  // namespace
   1187 
   1188 bool Annotator::ModelClassifyText(
   1189     const std::string& context, const std::vector<Token>& cached_tokens,
   1190     const std::vector<Locale>& detected_text_language_tags,
   1191     CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
   1192     FeatureProcessor::EmbeddingCache* embedding_cache,
   1193     std::vector<ClassificationResult>* classification_results) const {
   1194   std::vector<Token> tokens;
   1195   return ModelClassifyText(context, cached_tokens, detected_text_language_tags,
   1196                            selection_indices, interpreter_manager,
   1197                            embedding_cache, classification_results, &tokens);
   1198 }
   1199 
   1200 bool Annotator::ModelClassifyText(
   1201     const std::string& context, const std::vector<Token>& cached_tokens,
   1202     const std::vector<Locale>& detected_text_language_tags,
   1203     CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
   1204     FeatureProcessor::EmbeddingCache* embedding_cache,
   1205     std::vector<ClassificationResult>* classification_results,
   1206     std::vector<Token>* tokens) const {
   1207   if (model_->triggering_options() == nullptr ||
   1208       !(model_->triggering_options()->enabled_modes() &
   1209         ModeFlag_CLASSIFICATION)) {
   1210     return true;
   1211   }
   1212 
   1213   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
   1214                                     ml_model_triggering_locales_,
   1215                                     /*default_value=*/true)) {
   1216     return true;
   1217   }
   1218 
   1219   if (cached_tokens.empty()) {
   1220     *tokens = classification_feature_processor_->Tokenize(context);
   1221   } else {
   1222     *tokens = internal::CopyCachedTokens(cached_tokens, selection_indices,
   1223                                          ClassifyTextUpperBoundNeededTokens());
   1224   }
   1225 
   1226   int click_pos;
   1227   classification_feature_processor_->RetokenizeAndFindClick(
   1228       context, selection_indices,
   1229       classification_feature_processor_->GetOptions()
   1230           ->only_use_line_with_click(),
   1231       tokens, &click_pos);
   1232   const TokenSpan selection_token_span =
   1233       CodepointSpanToTokenSpan(*tokens, selection_indices);
   1234   const int selection_num_tokens = TokenSpanSize(selection_token_span);
   1235   if (model_->classification_options()->max_num_tokens() > 0 &&
   1236       model_->classification_options()->max_num_tokens() <
   1237           selection_num_tokens) {
   1238     *classification_results = {{Collections::Other(), 1.0}};
   1239     return true;
   1240   }
   1241 
   1242   const FeatureProcessorOptions_::BoundsSensitiveFeatures*
   1243       bounds_sensitive_features =
   1244           classification_feature_processor_->GetOptions()
   1245               ->bounds_sensitive_features();
   1246   if (selection_token_span.first == kInvalidIndex ||
   1247       selection_token_span.second == kInvalidIndex) {
   1248     TC3_LOG(ERROR) << "Could not determine span.";
   1249     return false;
   1250   }
   1251 
   1252   // Compute the extraction span based on the model type.
   1253   TokenSpan extraction_span;
   1254   if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
   1255     // The extraction span is the selection span expanded to include a relevant
   1256     // number of tokens outside of the bounds of the selection.
   1257     extraction_span = ExpandTokenSpan(
   1258         selection_token_span,
   1259         /*num_tokens_left=*/bounds_sensitive_features->num_tokens_before(),
   1260         /*num_tokens_right=*/bounds_sensitive_features->num_tokens_after());
   1261   } else {
   1262     if (click_pos == kInvalidIndex) {
   1263       TC3_LOG(ERROR) << "Couldn't choose a click position.";
   1264       return false;
   1265     }
   1266     // The extraction span is the clicked token with context_size tokens on
   1267     // either side.
   1268     const int context_size =
   1269         classification_feature_processor_->GetOptions()->context_size();
   1270     extraction_span = ExpandTokenSpan(SingleTokenSpan(click_pos),
   1271                                       /*num_tokens_left=*/context_size,
   1272                                       /*num_tokens_right=*/context_size);
   1273   }
   1274   extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});
   1275 
   1276   if (!classification_feature_processor_->HasEnoughSupportedCodepoints(
   1277           *tokens, extraction_span)) {
   1278     *classification_results = {{Collections::Other(), 1.0}};
   1279     return true;
   1280   }
   1281 
   1282   std::unique_ptr<CachedFeatures> cached_features;
   1283   if (!classification_feature_processor_->ExtractFeatures(
   1284           *tokens, extraction_span, selection_indices,
   1285           embedding_executor_.get(), embedding_cache,
   1286           classification_feature_processor_->EmbeddingSize() +
   1287               classification_feature_processor_->DenseFeaturesCount(),
   1288           &cached_features)) {
   1289     TC3_LOG(ERROR) << "Could not extract features.";
   1290     return false;
   1291   }
   1292 
   1293   std::vector<float> features;
   1294   features.reserve(cached_features->OutputFeaturesSize());
   1295   if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
   1296     cached_features->AppendBoundsSensitiveFeaturesForSpan(selection_token_span,
   1297                                                           &features);
   1298   } else {
   1299     cached_features->AppendClickContextFeaturesForClick(click_pos, &features);
   1300   }
   1301 
   1302   TensorView<float> logits = classification_executor_->ComputeLogits(
   1303       TensorView<float>(features.data(),
   1304                         {1, static_cast<int>(features.size())}),
   1305       interpreter_manager->ClassificationInterpreter());
   1306   if (!logits.is_valid()) {
   1307     TC3_LOG(ERROR) << "Couldn't compute logits.";
   1308     return false;
   1309   }
   1310 
   1311   if (logits.dims() != 2 || logits.dim(0) != 1 ||
   1312       logits.dim(1) != classification_feature_processor_->NumCollections()) {
   1313     TC3_LOG(ERROR) << "Mismatching output";
   1314     return false;
   1315   }
   1316 
   1317   const std::vector<float> scores =
   1318       ComputeSoftmax(logits.data(), logits.dim(1));
   1319 
   1320   if (scores.empty()) {
   1321     *classification_results = {{Collections::Other(), 1.0}};
   1322     return true;
   1323   }
   1324 
   1325   const int best_score_index =
   1326       std::max_element(scores.begin(), scores.end()) - scores.begin();
   1327   const std::string top_collection =
   1328       classification_feature_processor_->LabelToCollection(best_score_index);
   1329 
   1330   // Sanity checks.
   1331   if (top_collection == Collections::Phone()) {
   1332     const int digit_count = CountDigits(context, selection_indices);
   1333     if (digit_count <
   1334             model_->classification_options()->phone_min_num_digits() ||
   1335         digit_count >
   1336             model_->classification_options()->phone_max_num_digits()) {
   1337       *classification_results = {{Collections::Other(), 1.0}};
   1338       return true;
   1339     }
   1340   } else if (top_collection == Collections::Address()) {
   1341     if (selection_num_tokens <
   1342         model_->classification_options()->address_min_num_tokens()) {
   1343       *classification_results = {{Collections::Other(), 1.0}};
   1344       return true;
   1345     }
   1346   } else if (top_collection == Collections::Dictionary()) {
   1347     if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
   1348                                       dictionary_locales_,
   1349                                       /*default_value=*/false)) {
   1350       *classification_results = {{Collections::Other(), 1.0}};
   1351       return true;
   1352     }
   1353   }
   1354 
   1355   *classification_results = {{top_collection, 1.0, scores[best_score_index]}};
   1356   return true;
   1357 }
   1358 
   1359 bool Annotator::RegexClassifyText(
   1360     const std::string& context, CodepointSpan selection_indices,
   1361     std::vector<ClassificationResult>* classification_result) const {
   1362   const std::string selection_text =
   1363       UTF8ToUnicodeText(context, /*do_copy=*/false)
   1364           .UTF8Substring(selection_indices.first, selection_indices.second);
   1365   const UnicodeText selection_text_unicode(
   1366       UTF8ToUnicodeText(selection_text, /*do_copy=*/false));
   1367 
   1368   // Check whether any of the regular expressions match.
   1369   for (const int pattern_id : classification_regex_patterns_) {
   1370     const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
   1371     const std::unique_ptr<UniLib::RegexMatcher> matcher =
   1372         regex_pattern.pattern->Matcher(selection_text_unicode);
   1373     int status = UniLib::RegexMatcher::kNoError;
   1374     bool matches;
   1375     if (regex_pattern.config->use_approximate_matching()) {
   1376       matches = matcher->ApproximatelyMatches(&status);
   1377     } else {
   1378       matches = matcher->Matches(&status);
   1379     }
   1380     if (status != UniLib::RegexMatcher::kNoError) {
   1381       return false;
   1382     }
   1383     if (matches && VerifyRegexMatchCandidate(
   1384                        context, regex_pattern.config->verification_options(),
   1385                        selection_text, matcher.get())) {
   1386       classification_result->push_back(
   1387           {regex_pattern.config->collection_name()->str(),
   1388            regex_pattern.config->target_classification_score(),
   1389            regex_pattern.config->priority_score()});
   1390       if (!SerializedEntityDataFromRegexMatch(
   1391               regex_pattern.config, matcher.get(),
   1392               &classification_result->back().serialized_entity_data)) {
   1393         TC3_LOG(ERROR) << "Could not get entity data.";
   1394         return false;
   1395       }
   1396     }
   1397   }
   1398 
   1399   return true;
   1400 }
   1401 
   1402 namespace {
   1403 std::string PickCollectionForDatetime(
   1404     const DatetimeParseResult& datetime_parse_result) {
   1405   switch (datetime_parse_result.granularity) {
   1406     case GRANULARITY_HOUR:
   1407     case GRANULARITY_MINUTE:
   1408     case GRANULARITY_SECOND:
   1409       return Collections::DateTime();
   1410     default:
   1411       return Collections::Date();
   1412   }
   1413 }
   1414 
   1415 std::string CreateDatetimeSerializedEntityData(
   1416     const DatetimeParseResult& parse_result) {
   1417   EntityDataT entity_data;
   1418   entity_data.datetime.reset(new EntityData_::DatetimeT());
   1419   entity_data.datetime->time_ms_utc = parse_result.time_ms_utc;
   1420   entity_data.datetime->granularity =
   1421       static_cast<EntityData_::Datetime_::Granularity>(
   1422           parse_result.granularity);
   1423 
   1424   flatbuffers::FlatBufferBuilder builder;
   1425   FinishEntityDataBuffer(builder, EntityData::Pack(builder, &entity_data));
   1426   return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
   1427                      builder.GetSize());
   1428 }
   1429 }  // namespace
   1430 
   1431 bool Annotator::DatetimeClassifyText(
   1432     const std::string& context, CodepointSpan selection_indices,
   1433     const ClassificationOptions& options,
   1434     std::vector<ClassificationResult>* classification_results) const {
   1435   if (!datetime_parser_) {
   1436     return false;
   1437   }
   1438 
   1439   const std::string selection_text =
   1440       UTF8ToUnicodeText(context, /*do_copy=*/false)
   1441           .UTF8Substring(selection_indices.first, selection_indices.second);
   1442 
   1443   std::vector<DatetimeParseResultSpan> datetime_spans;
   1444   if (!datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
   1445                                options.reference_timezone, options.locales,
   1446                                ModeFlag_CLASSIFICATION,
   1447                                options.annotation_usecase,
   1448                                /*anchor_start_end=*/true, &datetime_spans)) {
   1449     TC3_LOG(ERROR) << "Error during parsing datetime.";
   1450     return false;
   1451   }
   1452   for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
   1453     // Only consider the result valid if the selection and extracted datetime
   1454     // spans exactly match.
   1455     if (std::make_pair(datetime_span.span.first + selection_indices.first,
   1456                        datetime_span.span.second + selection_indices.first) ==
   1457         selection_indices) {
   1458       for (const DatetimeParseResult& parse_result : datetime_span.data) {
   1459         classification_results->emplace_back(
   1460             PickCollectionForDatetime(parse_result),
   1461             datetime_span.target_classification_score);
   1462         classification_results->back().datetime_parse_result = parse_result;
   1463         classification_results->back().serialized_entity_data =
   1464             CreateDatetimeSerializedEntityData(parse_result);
   1465         classification_results->back().priority_score =
   1466             datetime_span.priority_score;
   1467       }
   1468       return true;
   1469     }
   1470   }
   1471   return true;
   1472 }
   1473 
   1474 std::vector<ClassificationResult> Annotator::ClassifyText(
   1475     const std::string& context, CodepointSpan selection_indices,
   1476     const ClassificationOptions& options) const {
   1477   if (!initialized_) {
   1478     TC3_LOG(ERROR) << "Not initialized";
   1479     return {};
   1480   }
   1481 
   1482   if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
   1483     return {};
   1484   }
   1485 
   1486   std::vector<Locale> detected_text_language_tags;
   1487   if (!ParseLocales(options.detected_text_language_tags,
   1488                     &detected_text_language_tags)) {
   1489     TC3_LOG(WARNING)
   1490         << "Failed to parse the detected_text_language_tags in options: "
   1491         << options.detected_text_language_tags;
   1492   }
   1493   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
   1494                                     model_triggering_locales_,
   1495                                     /*default_value=*/true)) {
   1496     return {};
   1497   }
   1498 
   1499   if (!UTF8ToUnicodeText(context, /*do_copy=*/false).is_valid()) {
   1500     return {};
   1501   }
   1502 
   1503   if (std::get<0>(selection_indices) >= std::get<1>(selection_indices)) {
   1504     TC3_VLOG(1) << "Trying to run ClassifyText with invalid indices: "
   1505                 << std::get<0>(selection_indices) << " "
   1506                 << std::get<1>(selection_indices);
   1507     return {};
   1508   }
   1509 
   1510   // We'll accumulate a list of candidates, and pick the best candidate in the
   1511   // end.
   1512   std::vector<AnnotatedSpan> candidates;
   1513 
   1514   // Try the knowledge engine.
   1515   // TODO(b/126579108): Propagate error status.
   1516   ClassificationResult knowledge_result;
   1517   if (knowledge_engine_ && knowledge_engine_->ClassifyText(
   1518                                context, selection_indices, &knowledge_result)) {
   1519     candidates.push_back({selection_indices, {knowledge_result}});
   1520     candidates.back().source = AnnotatedSpan::Source::KNOWLEDGE;
   1521   }
   1522 
   1523   // Try the contact engine.
   1524   // TODO(b/126579108): Propagate error status.
   1525   ClassificationResult contact_result;
   1526   if (contact_engine_ && contact_engine_->ClassifyText(
   1527                              context, selection_indices, &contact_result)) {
   1528     candidates.push_back({selection_indices, {contact_result}});
   1529   }
   1530 
   1531   // Try the installed app engine.
   1532   // TODO(b/126579108): Propagate error status.
   1533   ClassificationResult installed_app_result;
   1534   if (installed_app_engine_ &&
   1535       installed_app_engine_->ClassifyText(context, selection_indices,
   1536                                           &installed_app_result)) {
   1537     candidates.push_back({selection_indices, {installed_app_result}});
   1538   }
   1539 
   1540   // Try the regular expression models.
   1541   std::vector<ClassificationResult> regex_results;
   1542   if (!RegexClassifyText(context, selection_indices, &regex_results)) {
   1543     return {};
   1544   }
   1545   for (const ClassificationResult& result : regex_results) {
   1546     candidates.push_back({selection_indices, {result}});
   1547   }
   1548 
   1549   // Try the date model.
   1550   //
   1551   // DatetimeClassifyText only returns the first result, which can however have
   1552   // more interpretations. They are inserted in the candidates as a single
   1553   // AnnotatedSpan, so that they get treated together by the conflict resolution
   1554   // algorithm.
   1555   std::vector<ClassificationResult> datetime_results;
   1556   if (!DatetimeClassifyText(context, selection_indices, options,
   1557                             &datetime_results)) {
   1558     return {};
   1559   }
   1560   if (!datetime_results.empty()) {
   1561     candidates.push_back({selection_indices, std::move(datetime_results)});
   1562     candidates.back().source = AnnotatedSpan::Source::DATETIME;
   1563   }
   1564 
   1565   // Try the number annotator.
   1566   // TODO(b/126579108): Propagate error status.
   1567   ClassificationResult number_annotator_result;
   1568   if (number_annotator_ &&
   1569       number_annotator_->ClassifyText(
   1570           UTF8ToUnicodeText(context, /*do_copy=*/false), selection_indices,
   1571           options.annotation_usecase, &number_annotator_result)) {
   1572     candidates.push_back({selection_indices, {number_annotator_result}});
   1573   }
   1574 
   1575   // Try the duration annotator.
   1576   ClassificationResult duration_annotator_result;
   1577   if (duration_annotator_ &&
   1578       duration_annotator_->ClassifyText(
   1579           UTF8ToUnicodeText(context, /*do_copy=*/false), selection_indices,
   1580           options.annotation_usecase, &duration_annotator_result)) {
   1581     candidates.push_back({selection_indices, {duration_annotator_result}});
   1582     candidates.back().source = AnnotatedSpan::Source::DURATION;
   1583   }
   1584 
   1585   // Try the ML model.
   1586   //
   1587   // The output of the model is considered as an exclusive 1-of-N choice. That's
   1588   // why it's inserted as only 1 AnnotatedSpan into candidates, as opposed to 1
   1589   // span for each candidate, like e.g. the regex model.
   1590   InterpreterManager interpreter_manager(selection_executor_.get(),
   1591                                          classification_executor_.get());
   1592   std::vector<ClassificationResult> model_results;
   1593   std::vector<Token> tokens;
   1594   if (!ModelClassifyText(
   1595           context, /*cached_tokens=*/{}, detected_text_language_tags,
   1596           selection_indices, &interpreter_manager,
   1597           /*embedding_cache=*/nullptr, &model_results, &tokens)) {
   1598     return {};
   1599   }
   1600   if (!model_results.empty()) {
   1601     candidates.push_back({selection_indices, std::move(model_results)});
   1602   }
   1603 
   1604   std::vector<int> candidate_indices;
   1605   if (!ResolveConflicts(candidates, context, tokens,
   1606                         detected_text_language_tags, options.annotation_usecase,
   1607                         &interpreter_manager, &candidate_indices)) {
   1608     TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
   1609     return {};
   1610   }
   1611 
   1612   std::vector<ClassificationResult> results;
   1613   for (const int i : candidate_indices) {
   1614     for (const ClassificationResult& result : candidates[i].classification) {
   1615       if (!FilteredForClassification(result)) {
   1616         results.push_back(result);
   1617       }
   1618     }
   1619   }
   1620 
   1621   // Sort results according to score.
   1622   std::sort(results.begin(), results.end(),
   1623             [](const ClassificationResult& a, const ClassificationResult& b) {
   1624               return a.score > b.score;
   1625             });
   1626 
   1627   if (results.empty()) {
   1628     results = {{Collections::Other(), 1.0}};
   1629   }
   1630   return results;
   1631 }
   1632 
   1633 bool Annotator::ModelAnnotate(
   1634     const std::string& context,
   1635     const std::vector<Locale>& detected_text_language_tags,
   1636     InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
   1637     std::vector<AnnotatedSpan>* result) const {
   1638   if (model_->triggering_options() == nullptr ||
   1639       !(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
   1640     return true;
   1641   }
   1642 
   1643   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
   1644                                     ml_model_triggering_locales_,
   1645                                     /*default_value=*/true)) {
   1646     return true;
   1647   }
   1648 
   1649   const UnicodeText context_unicode = UTF8ToUnicodeText(context,
   1650                                                         /*do_copy=*/false);
   1651   std::vector<UnicodeTextRange> lines;
   1652   if (!selection_feature_processor_->GetOptions()->only_use_line_with_click()) {
   1653     lines.push_back({context_unicode.begin(), context_unicode.end()});
   1654   } else {
   1655     lines = selection_feature_processor_->SplitContext(context_unicode);
   1656   }
   1657 
   1658   const float min_annotate_confidence =
   1659       (model_->triggering_options() != nullptr
   1660            ? model_->triggering_options()->min_annotate_confidence()
   1661            : 0.f);
   1662 
   1663   for (const UnicodeTextRange& line : lines) {
   1664     FeatureProcessor::EmbeddingCache embedding_cache;
   1665     const std::string line_str =
   1666         UnicodeText::UTF8Substring(line.first, line.second);
   1667 
   1668     *tokens = selection_feature_processor_->Tokenize(line_str);
   1669     selection_feature_processor_->RetokenizeAndFindClick(
   1670         line_str, {0, std::distance(line.first, line.second)},
   1671         selection_feature_processor_->GetOptions()->only_use_line_with_click(),
   1672         tokens,
   1673         /*click_pos=*/nullptr);
   1674     const TokenSpan full_line_span = {0, tokens->size()};
   1675 
   1676     // TODO(zilka): Add support for greater granularity of this check.
   1677     if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
   1678             *tokens, full_line_span)) {
   1679       continue;
   1680     }
   1681 
   1682     std::unique_ptr<CachedFeatures> cached_features;
   1683     if (!selection_feature_processor_->ExtractFeatures(
   1684             *tokens, full_line_span,
   1685             /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
   1686             embedding_executor_.get(),
   1687             /*embedding_cache=*/nullptr,
   1688             selection_feature_processor_->EmbeddingSize() +
   1689                 selection_feature_processor_->DenseFeaturesCount(),
   1690             &cached_features)) {
   1691       TC3_LOG(ERROR) << "Could not extract features.";
   1692       return false;
   1693     }
   1694 
   1695     std::vector<TokenSpan> local_chunks;
   1696     if (!ModelChunk(tokens->size(), /*span_of_interest=*/full_line_span,
   1697                     interpreter_manager->SelectionInterpreter(),
   1698                     *cached_features, &local_chunks)) {
   1699       TC3_LOG(ERROR) << "Could not chunk.";
   1700       return false;
   1701     }
   1702 
   1703     const int offset = std::distance(context_unicode.begin(), line.first);
   1704     for (const TokenSpan& chunk : local_chunks) {
   1705       const CodepointSpan codepoint_span =
   1706           selection_feature_processor_->StripBoundaryCodepoints(
   1707               line_str, TokenSpanToCodepointSpan(*tokens, chunk));
   1708 
   1709       // Skip empty spans.
   1710       if (codepoint_span.first != codepoint_span.second) {
   1711         std::vector<ClassificationResult> classification;
   1712         if (!ModelClassifyText(line_str, *tokens, detected_text_language_tags,
   1713                                codepoint_span, interpreter_manager,
   1714                                &embedding_cache, &classification)) {
   1715           TC3_LOG(ERROR) << "Could not classify text: "
   1716                          << (codepoint_span.first + offset) << " "
   1717                          << (codepoint_span.second + offset);
   1718           return false;
   1719         }
   1720 
   1721         // Do not include the span if it's classified as "other".
   1722         if (!classification.empty() && !ClassifiedAsOther(classification) &&
   1723             classification[0].score >= min_annotate_confidence) {
   1724           AnnotatedSpan result_span;
   1725           result_span.span = {codepoint_span.first + offset,
   1726                               codepoint_span.second + offset};
   1727           result_span.classification = std::move(classification);
   1728           result->push_back(std::move(result_span));
   1729         }
   1730       }
   1731     }
   1732   }
   1733   return true;
   1734 }
   1735 
   1736 const FeatureProcessor* Annotator::SelectionFeatureProcessorForTests() const {
   1737   return selection_feature_processor_.get();
   1738 }
   1739 
   1740 const FeatureProcessor* Annotator::ClassificationFeatureProcessorForTests()
   1741     const {
   1742   return classification_feature_processor_.get();
   1743 }
   1744 
   1745 const DatetimeParser* Annotator::DatetimeParserForTests() const {
   1746   return datetime_parser_.get();
   1747 }
   1748 
   1749 void Annotator::RemoveNotEnabledEntityTypes(
   1750     const EnabledEntityTypes& is_entity_type_enabled,
   1751     std::vector<AnnotatedSpan>* annotated_spans) const {
   1752   for (AnnotatedSpan& annotated_span : *annotated_spans) {
   1753     std::vector<ClassificationResult>& classifications =
   1754         annotated_span.classification;
   1755     classifications.erase(
   1756         std::remove_if(classifications.begin(), classifications.end(),
   1757                        [&is_entity_type_enabled](
   1758                            const ClassificationResult& classification_result) {
   1759                          return !is_entity_type_enabled(
   1760                              classification_result.collection);
   1761                        }),
   1762         classifications.end());
   1763   }
   1764   annotated_spans->erase(
   1765       std::remove_if(annotated_spans->begin(), annotated_spans->end(),
   1766                      [](const AnnotatedSpan& annotated_span) {
   1767                        return annotated_span.classification.empty();
   1768                      }),
   1769       annotated_spans->end());
   1770 }
   1771 
   1772 std::vector<AnnotatedSpan> Annotator::Annotate(
   1773     const std::string& context, const AnnotationOptions& options) const {
   1774   std::vector<AnnotatedSpan> candidates;
   1775 
   1776   if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) {
   1777     return {};
   1778   }
   1779 
   1780   const UnicodeText context_unicode =
   1781       UTF8ToUnicodeText(context, /*do_copy=*/false);
   1782   if (!context_unicode.is_valid()) {
   1783     return {};
   1784   }
   1785 
   1786   std::vector<Locale> detected_text_language_tags;
   1787   if (!ParseLocales(options.detected_text_language_tags,
   1788                     &detected_text_language_tags)) {
   1789     TC3_LOG(WARNING)
   1790         << "Failed to parse the detected_text_language_tags in options: "
   1791         << options.detected_text_language_tags;
   1792   }
   1793   if (!Locale::IsAnyLocaleSupported(detected_text_language_tags,
   1794                                     model_triggering_locales_,
   1795                                     /*default_value=*/true)) {
   1796     return {};
   1797   }
   1798 
   1799   InterpreterManager interpreter_manager(selection_executor_.get(),
   1800                                          classification_executor_.get());
   1801 
   1802   // Annotate with the selection model.
   1803   std::vector<Token> tokens;
   1804   if (!ModelAnnotate(context, detected_text_language_tags, &interpreter_manager,
   1805                      &tokens, &candidates)) {
   1806     TC3_LOG(ERROR) << "Couldn't run ModelAnnotate.";
   1807     return {};
   1808   }
   1809 
   1810   // Annotate with the regular expression models.
   1811   if (!RegexChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
   1812                   annotation_regex_patterns_, &candidates,
   1813                   options.is_serialized_entity_data_enabled)) {
   1814     TC3_LOG(ERROR) << "Couldn't run RegexChunk.";
   1815     return {};
   1816   }
   1817 
   1818   // Annotate with the datetime model.
   1819   const EnabledEntityTypes is_entity_type_enabled(options.entity_types);
   1820   if ((is_entity_type_enabled(Collections::Date()) ||
   1821        is_entity_type_enabled(Collections::DateTime())) &&
   1822       !DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
   1823                      options.reference_time_ms_utc, options.reference_timezone,
   1824                      options.locales, ModeFlag_ANNOTATION,
   1825                      options.annotation_usecase,
   1826                      options.is_serialized_entity_data_enabled, &candidates)) {
   1827     TC3_LOG(ERROR) << "Couldn't run RegexChunk.";
   1828     return {};
   1829   }
   1830 
   1831   // Annotate with the knowledge engine.
   1832   if (knowledge_engine_ && !knowledge_engine_->Chunk(context, &candidates)) {
   1833     TC3_LOG(ERROR) << "Couldn't run knowledge engine Chunk.";
   1834     return {};
   1835   }
   1836 
   1837   // Annotate with the contact engine.
   1838   if (contact_engine_ &&
   1839       !contact_engine_->Chunk(context_unicode, tokens, &candidates)) {
   1840     TC3_LOG(ERROR) << "Couldn't run contact engine Chunk.";
   1841     return {};
   1842   }
   1843 
   1844   // Annotate with the installed app engine.
   1845   if (installed_app_engine_ &&
   1846       !installed_app_engine_->Chunk(context_unicode, tokens, &candidates)) {
   1847     TC3_LOG(ERROR) << "Couldn't run installed app engine Chunk.";
   1848     return {};
   1849   }
   1850 
   1851   // Annotate with the number annotator.
   1852   if (number_annotator_ != nullptr &&
   1853       !number_annotator_->FindAll(context_unicode, options.annotation_usecase,
   1854                                   &candidates)) {
   1855     TC3_LOG(ERROR) << "Couldn't run number annotator FindAll.";
   1856     return {};
   1857   }
   1858 
   1859   // Annotate with the duration annotator.
   1860   if (is_entity_type_enabled(Collections::Duration()) &&
   1861       duration_annotator_ != nullptr &&
   1862       !duration_annotator_->FindAll(context_unicode, tokens,
   1863                                     options.annotation_usecase, &candidates)) {
   1864     TC3_LOG(ERROR) << "Couldn't run duration annotator FindAll.";
   1865     return {};
   1866   }
   1867 
   1868   // Sort candidates according to their position in the input, so that the next
   1869   // code can assume that any connected component of overlapping spans forms a
   1870   // contiguous block.
   1871   std::sort(candidates.begin(), candidates.end(),
   1872             [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
   1873               return a.span.first < b.span.first;
   1874             });
   1875 
   1876   std::vector<int> candidate_indices;
   1877   if (!ResolveConflicts(candidates, context, tokens,
   1878                         detected_text_language_tags, options.annotation_usecase,
   1879                         &interpreter_manager, &candidate_indices)) {
   1880     TC3_LOG(ERROR) << "Couldn't resolve conflicts.";
   1881     return {};
   1882   }
   1883 
   1884   std::vector<AnnotatedSpan> result;
   1885   result.reserve(candidate_indices.size());
   1886   AnnotatedSpan aggregated_span;
   1887   for (const int i : candidate_indices) {
   1888     if (candidates[i].span != aggregated_span.span) {
   1889       if (!aggregated_span.classification.empty()) {
   1890         result.push_back(std::move(aggregated_span));
   1891       }
   1892       aggregated_span =
   1893           AnnotatedSpan(candidates[i].span, /*arg_classification=*/{});
   1894     }
   1895     if (candidates[i].classification.empty() ||
   1896         ClassifiedAsOther(candidates[i].classification) ||
   1897         FilteredForAnnotation(candidates[i])) {
   1898       continue;
   1899     }
   1900     for (ClassificationResult& classification : candidates[i].classification) {
   1901       aggregated_span.classification.push_back(std::move(classification));
   1902     }
   1903   }
   1904   if (!aggregated_span.classification.empty()) {
   1905     result.push_back(std::move(aggregated_span));
   1906   }
   1907 
   1908   // We generate all candidates and remove them later (with the exception of
   1909   // date/time/duration entities) because there are complex interdependencies
   1910   // between the entity types. E.g., the TLD of an email can be interpreted as a
   1911   // URL, but most likely a user of the API does not want such annotations if
   1912   // "url" is enabled and "email" is not.
   1913   RemoveNotEnabledEntityTypes(is_entity_type_enabled, &result);
   1914 
   1915   for (AnnotatedSpan& annotated_span : result) {
   1916     SortClassificationResults(&annotated_span.classification);
   1917   }
   1918 
   1919   return result;
   1920 }
   1921 
   1922 CodepointSpan Annotator::ComputeSelectionBoundaries(
   1923     const UniLib::RegexMatcher* match,
   1924     const RegexModel_::Pattern* config) const {
   1925   if (config->capturing_group() == nullptr) {
   1926     // Use first capturing group to specify the selection.
   1927     int status = UniLib::RegexMatcher::kNoError;
   1928     const CodepointSpan result = {match->Start(1, &status),
   1929                                   match->End(1, &status)};
   1930     if (status != UniLib::RegexMatcher::kNoError) {
   1931       return {kInvalidIndex, kInvalidIndex};
   1932     }
   1933     return result;
   1934   }
   1935 
   1936   CodepointSpan result = {kInvalidIndex, kInvalidIndex};
   1937   const int num_groups = config->capturing_group()->size();
   1938   for (int i = 0; i < num_groups; i++) {
   1939     if (!config->capturing_group()->Get(i)->extend_selection()) {
   1940       continue;
   1941     }
   1942 
   1943     int status = UniLib::RegexMatcher::kNoError;
   1944     // Check match and adjust bounds.
   1945     const int group_start = match->Start(i, &status);
   1946     const int group_end = match->End(i, &status);
   1947     if (status != UniLib::RegexMatcher::kNoError) {
   1948       return {kInvalidIndex, kInvalidIndex};
   1949     }
   1950     if (group_start == kInvalidIndex || group_end == kInvalidIndex) {
   1951       continue;
   1952     }
   1953     if (result.first == kInvalidIndex) {
   1954       result = {group_start, group_end};
   1955     } else {
   1956       result.first = std::min(result.first, group_start);
   1957       result.second = std::max(result.second, group_end);
   1958     }
   1959   }
   1960   return result;
   1961 }
   1962 
   1963 bool Annotator::HasEntityData(const RegexModel_::Pattern* pattern) const {
   1964   if (pattern->serialized_entity_data() != nullptr) {
   1965     return true;
   1966   }
   1967   if (pattern->capturing_group() != nullptr) {
   1968     for (const RegexModel_::Pattern_::CapturingGroup* group :
   1969          *pattern->capturing_group()) {
   1970       if (group->entity_field_path() != nullptr) {
   1971         return true;
   1972       }
   1973     }
   1974   }
   1975   return false;
   1976 }
   1977 
   1978 bool Annotator::SerializedEntityDataFromRegexMatch(
   1979     const RegexModel_::Pattern* pattern, UniLib::RegexMatcher* matcher,
   1980     std::string* serialized_entity_data) const {
   1981   if (!HasEntityData(pattern)) {
   1982     serialized_entity_data->clear();
   1983     return true;
   1984   }
   1985   TC3_CHECK(entity_data_builder_ != nullptr);
   1986 
   1987   std::unique_ptr<ReflectiveFlatbuffer> entity_data =
   1988       entity_data_builder_->NewRoot();
   1989 
   1990   TC3_CHECK(entity_data != nullptr);
   1991 
   1992   // Set static entity data.
   1993   if (pattern->serialized_entity_data() != nullptr) {
   1994     TC3_CHECK(entity_data != nullptr);
   1995     entity_data->MergeFromSerializedFlatbuffer(
   1996         StringPiece(pattern->serialized_entity_data()->c_str(),
   1997                     pattern->serialized_entity_data()->size()));
   1998   }
   1999 
   2000   // Add entity data from rule capturing groups.
   2001   if (pattern->capturing_group() != nullptr) {
   2002     const int num_groups = pattern->capturing_group()->size();
   2003     for (int i = 0; i < num_groups; i++) {
   2004       const FlatbufferFieldPath* field_path =
   2005           pattern->capturing_group()->Get(i)->entity_field_path();
   2006       if (field_path == nullptr) {
   2007         continue;
   2008       }
   2009       TC3_CHECK(entity_data != nullptr);
   2010       if (!SetFieldFromCapturingGroup(/*group_id=*/i, field_path, matcher,
   2011                                       entity_data.get())) {
   2012         TC3_LOG(ERROR)
   2013             << "Could not set entity data from rule capturing group.";
   2014         return false;
   2015       }
   2016     }
   2017   }
   2018 
   2019   *serialized_entity_data = entity_data->Serialize();
   2020   return true;
   2021 }
   2022 
   2023 bool Annotator::RegexChunk(const UnicodeText& context_unicode,
   2024                            const std::vector<int>& rules,
   2025                            std::vector<AnnotatedSpan>* result,
   2026                            bool is_serialized_entity_data_enabled) const {
   2027   for (int pattern_id : rules) {
   2028     const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
   2029     const auto matcher = regex_pattern.pattern->Matcher(context_unicode);
   2030     if (!matcher) {
   2031       TC3_LOG(ERROR) << "Could not get regex matcher for pattern: "
   2032                      << pattern_id;
   2033       return false;
   2034     }
   2035 
   2036     int status = UniLib::RegexMatcher::kNoError;
   2037     while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
   2038       if (regex_pattern.config->verification_options()) {
   2039         if (!VerifyRegexMatchCandidate(
   2040                 context_unicode.ToUTF8String(),
   2041                 regex_pattern.config->verification_options(),
   2042                 matcher->Group(1, &status).ToUTF8String(), matcher.get())) {
   2043           continue;
   2044         }
   2045       }
   2046 
   2047       std::string serialized_entity_data;
   2048       if (is_serialized_entity_data_enabled) {
   2049         if (!SerializedEntityDataFromRegexMatch(
   2050                 regex_pattern.config, matcher.get(), &serialized_entity_data)) {
   2051           TC3_LOG(ERROR) << "Could not get entity data.";
   2052           return false;
   2053         }
   2054       }
   2055 
   2056       result->emplace_back();
   2057 
   2058       // Selection/annotation regular expressions need to specify a capturing
   2059       // group specifying the selection.
   2060       result->back().span =
   2061           ComputeSelectionBoundaries(matcher.get(), regex_pattern.config);
   2062 
   2063       result->back().classification = {
   2064           {regex_pattern.config->collection_name()->str(),
   2065            regex_pattern.config->target_classification_score(),
   2066            regex_pattern.config->priority_score()}};
   2067 
   2068       result->back().classification[0].serialized_entity_data =
   2069           serialized_entity_data;
   2070     }
   2071   }
   2072   return true;
   2073 }
   2074 
   2075 bool Annotator::ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
   2076                            tflite::Interpreter* selection_interpreter,
   2077                            const CachedFeatures& cached_features,
   2078                            std::vector<TokenSpan>* chunks) const {
   2079   const int max_selection_span =
   2080       selection_feature_processor_->GetOptions()->max_selection_span();
   2081   // The inference span is the span of interest expanded to include
   2082   // max_selection_span tokens on either side, which is how far a selection can
   2083   // stretch from the click.
   2084   const TokenSpan inference_span = IntersectTokenSpans(
   2085       ExpandTokenSpan(span_of_interest,
   2086                       /*num_tokens_left=*/max_selection_span,
   2087                       /*num_tokens_right=*/max_selection_span),
   2088       {0, num_tokens});
   2089 
   2090   std::vector<ScoredChunk> scored_chunks;
   2091   if (selection_feature_processor_->GetOptions()->bounds_sensitive_features() &&
   2092       selection_feature_processor_->GetOptions()
   2093           ->bounds_sensitive_features()
   2094           ->enabled()) {
   2095     if (!ModelBoundsSensitiveScoreChunks(
   2096             num_tokens, span_of_interest, inference_span, cached_features,
   2097             selection_interpreter, &scored_chunks)) {
   2098       return false;
   2099     }
   2100   } else {
   2101     if (!ModelClickContextScoreChunks(num_tokens, span_of_interest,
   2102                                       cached_features, selection_interpreter,
   2103                                       &scored_chunks)) {
   2104       return false;
   2105     }
   2106   }
   2107   std::sort(scored_chunks.rbegin(), scored_chunks.rend(),
   2108             [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
   2109               return lhs.score < rhs.score;
   2110             });
   2111 
   2112   // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
   2113   // them greedily as long as they do not overlap with any previously picked
   2114   // chunks.
   2115   std::vector<bool> token_used(TokenSpanSize(inference_span));
   2116   chunks->clear();
   2117   for (const ScoredChunk& scored_chunk : scored_chunks) {
   2118     bool feasible = true;
   2119     for (int i = scored_chunk.token_span.first;
   2120          i < scored_chunk.token_span.second; ++i) {
   2121       if (token_used[i - inference_span.first]) {
   2122         feasible = false;
   2123         break;
   2124       }
   2125     }
   2126 
   2127     if (!feasible) {
   2128       continue;
   2129     }
   2130 
   2131     for (int i = scored_chunk.token_span.first;
   2132          i < scored_chunk.token_span.second; ++i) {
   2133       token_used[i - inference_span.first] = true;
   2134     }
   2135 
   2136     chunks->push_back(scored_chunk.token_span);
   2137   }
   2138 
   2139   std::sort(chunks->begin(), chunks->end());
   2140 
   2141   return true;
   2142 }
   2143 
   2144 namespace {
   2145 // Updates the value at the given key in the map to maximum of the current value
   2146 // and the given value, or simply inserts the value if the key is not yet there.
   2147 template <typename Map>
   2148 void UpdateMax(Map* map, typename Map::key_type key,
   2149                typename Map::mapped_type value) {
   2150   const auto it = map->find(key);
   2151   if (it != map->end()) {
   2152     it->second = std::max(it->second, value);
   2153   } else {
   2154     (*map)[key] = value;
   2155   }
   2156 }
   2157 }  // namespace
   2158 
   2159 bool Annotator::ModelClickContextScoreChunks(
   2160     int num_tokens, const TokenSpan& span_of_interest,
   2161     const CachedFeatures& cached_features,
   2162     tflite::Interpreter* selection_interpreter,
   2163     std::vector<ScoredChunk>* scored_chunks) const {
   2164   const int max_batch_size = model_->selection_options()->batch_size();
   2165 
   2166   std::vector<float> all_features;
   2167   std::map<TokenSpan, float> chunk_scores;
   2168   for (int batch_start = span_of_interest.first;
   2169        batch_start < span_of_interest.second; batch_start += max_batch_size) {
   2170     const int batch_end =
   2171         std::min(batch_start + max_batch_size, span_of_interest.second);
   2172 
   2173     // Prepare features for the whole batch.
   2174     all_features.clear();
   2175     all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
   2176     for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
   2177       cached_features.AppendClickContextFeaturesForClick(click_pos,
   2178                                                          &all_features);
   2179     }
   2180 
   2181     // Run batched inference.
   2182     const int batch_size = batch_end - batch_start;
   2183     const int features_size = cached_features.OutputFeaturesSize();
   2184     TensorView<float> logits = selection_executor_->ComputeLogits(
   2185         TensorView<float>(all_features.data(), {batch_size, features_size}),
   2186         selection_interpreter);
   2187     if (!logits.is_valid()) {
   2188       TC3_LOG(ERROR) << "Couldn't compute logits.";
   2189       return false;
   2190     }
   2191     if (logits.dims() != 2 || logits.dim(0) != batch_size ||
   2192         logits.dim(1) !=
   2193             selection_feature_processor_->GetSelectionLabelCount()) {
   2194       TC3_LOG(ERROR) << "Mismatching output.";
   2195       return false;
   2196     }
   2197 
   2198     // Save results.
   2199     for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
   2200       const std::vector<float> scores = ComputeSoftmax(
   2201           logits.data() + logits.dim(1) * (click_pos - batch_start),
   2202           logits.dim(1));
   2203       for (int j = 0;
   2204            j < selection_feature_processor_->GetSelectionLabelCount(); ++j) {
   2205         TokenSpan relative_token_span;
   2206         if (!selection_feature_processor_->LabelToTokenSpan(
   2207                 j, &relative_token_span)) {
   2208           TC3_LOG(ERROR) << "Couldn't map the label to a token span.";
   2209           return false;
   2210         }
   2211         const TokenSpan candidate_span = ExpandTokenSpan(
   2212             SingleTokenSpan(click_pos), relative_token_span.first,
   2213             relative_token_span.second);
   2214         if (candidate_span.first >= 0 && candidate_span.second <= num_tokens) {
   2215           UpdateMax(&chunk_scores, candidate_span, scores[j]);
   2216         }
   2217       }
   2218     }
   2219   }
   2220 
   2221   scored_chunks->clear();
   2222   scored_chunks->reserve(chunk_scores.size());
   2223   for (const auto& entry : chunk_scores) {
   2224     scored_chunks->push_back(ScoredChunk{entry.first, entry.second});
   2225   }
   2226 
   2227   return true;
   2228 }
   2229 
   2230 bool Annotator::ModelBoundsSensitiveScoreChunks(
   2231     int num_tokens, const TokenSpan& span_of_interest,
   2232     const TokenSpan& inference_span, const CachedFeatures& cached_features,
   2233     tflite::Interpreter* selection_interpreter,
   2234     std::vector<ScoredChunk>* scored_chunks) const {
   2235   const int max_selection_span =
   2236       selection_feature_processor_->GetOptions()->max_selection_span();
   2237   const int max_chunk_length = selection_feature_processor_->GetOptions()
   2238                                        ->selection_reduced_output_space()
   2239                                    ? max_selection_span + 1
   2240                                    : 2 * max_selection_span + 1;
   2241   const bool score_single_token_spans_as_zero =
   2242       selection_feature_processor_->GetOptions()
   2243           ->bounds_sensitive_features()
   2244           ->score_single_token_spans_as_zero();
   2245 
   2246   scored_chunks->clear();
   2247   if (score_single_token_spans_as_zero) {
   2248     scored_chunks->reserve(TokenSpanSize(span_of_interest));
   2249   }
   2250 
   2251   // Prepare all chunk candidates into one batch:
   2252   //   - Are contained in the inference span
   2253   //   - Have a non-empty intersection with the span of interest
   2254   //   - Are at least one token long
   2255   //   - Are not longer than the maximum chunk length
   2256   std::vector<TokenSpan> candidate_spans;
   2257   for (int start = inference_span.first; start < span_of_interest.second;
   2258        ++start) {
   2259     const int leftmost_end_index = std::max(start, span_of_interest.first) + 1;
   2260     for (int end = leftmost_end_index;
   2261          end <= inference_span.second && end - start <= max_chunk_length;
   2262          ++end) {
   2263       const TokenSpan candidate_span = {start, end};
   2264       if (score_single_token_spans_as_zero &&
   2265           TokenSpanSize(candidate_span) == 1) {
   2266         // Do not include the single token span in the batch, add a zero score
   2267         // for it directly to the output.
   2268         scored_chunks->push_back(ScoredChunk{candidate_span, 0.0f});
   2269       } else {
   2270         candidate_spans.push_back(candidate_span);
   2271       }
   2272     }
   2273   }
   2274 
   2275   const int max_batch_size = model_->selection_options()->batch_size();
   2276 
   2277   std::vector<float> all_features;
   2278   scored_chunks->reserve(scored_chunks->size() + candidate_spans.size());
   2279   for (int batch_start = 0; batch_start < candidate_spans.size();
   2280        batch_start += max_batch_size) {
   2281     const int batch_end = std::min(batch_start + max_batch_size,
   2282                                    static_cast<int>(candidate_spans.size()));
   2283 
   2284     // Prepare features for the whole batch.
   2285     all_features.clear();
   2286     all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
   2287     for (int i = batch_start; i < batch_end; ++i) {
   2288       cached_features.AppendBoundsSensitiveFeaturesForSpan(candidate_spans[i],
   2289                                                            &all_features);
   2290     }
   2291 
   2292     // Run batched inference.
   2293     const int batch_size = batch_end - batch_start;
   2294     const int features_size = cached_features.OutputFeaturesSize();
   2295     TensorView<float> logits = selection_executor_->ComputeLogits(
   2296         TensorView<float>(all_features.data(), {batch_size, features_size}),
   2297         selection_interpreter);
   2298     if (!logits.is_valid()) {
   2299       TC3_LOG(ERROR) << "Couldn't compute logits.";
   2300       return false;
   2301     }
   2302     if (logits.dims() != 2 || logits.dim(0) != batch_size ||
   2303         logits.dim(1) != 1) {
   2304       TC3_LOG(ERROR) << "Mismatching output.";
   2305       return false;
   2306     }
   2307 
   2308     // Save results.
   2309     for (int i = batch_start; i < batch_end; ++i) {
   2310       scored_chunks->push_back(
   2311           ScoredChunk{candidate_spans[i], logits.data()[i - batch_start]});
   2312     }
   2313   }
   2314 
   2315   return true;
   2316 }
   2317 
   2318 bool Annotator::DatetimeChunk(const UnicodeText& context_unicode,
   2319                               int64 reference_time_ms_utc,
   2320                               const std::string& reference_timezone,
   2321                               const std::string& locales, ModeFlag mode,
   2322                               AnnotationUsecase annotation_usecase,
   2323                               bool is_serialized_entity_data_enabled,
   2324                               std::vector<AnnotatedSpan>* result) const {
   2325   if (!datetime_parser_) {
   2326     return true;
   2327   }
   2328 
   2329   std::vector<DatetimeParseResultSpan> datetime_spans;
   2330   if (!datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
   2331                                reference_timezone, locales, mode,
   2332                                annotation_usecase,
   2333                                /*anchor_start_end=*/false, &datetime_spans)) {
   2334     return false;
   2335   }
   2336   for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
   2337     AnnotatedSpan annotated_span;
   2338     annotated_span.span = datetime_span.span;
   2339     for (const DatetimeParseResult& parse_result : datetime_span.data) {
   2340       annotated_span.classification.emplace_back(
   2341           PickCollectionForDatetime(parse_result),
   2342           datetime_span.target_classification_score,
   2343           datetime_span.priority_score);
   2344       annotated_span.classification.back().datetime_parse_result = parse_result;
   2345       if (is_serialized_entity_data_enabled) {
   2346         annotated_span.classification.back().serialized_entity_data =
   2347             CreateDatetimeSerializedEntityData(parse_result);
   2348       }
   2349     }
   2350     annotated_span.source = AnnotatedSpan::Source::DATETIME;
   2351     result->push_back(std::move(annotated_span));
   2352   }
   2353   return true;
   2354 }
   2355 
   2356 const Model* Annotator::model() const { return model_; }
   2357 const reflection::Schema* Annotator::entity_data_schema() const {
   2358   return entity_data_schema_;
   2359 }
   2360 
   2361 const Model* ViewModel(const void* buffer, int size) {
   2362   if (!buffer) {
   2363     return nullptr;
   2364   }
   2365 
   2366   return LoadAndVerifyModel(buffer, size);
   2367 }
   2368 
   2369 bool Annotator::LookUpKnowledgeEntity(
   2370     const std::string& id, std::string* serialized_knowledge_result) const {
   2371   return knowledge_engine_ &&
   2372          knowledge_engine_->LookUpEntity(id, serialized_knowledge_result);
   2373 }
   2374 
   2375 }  // namespace libtextclassifier3
   2376