/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "text-classifier.h"

#include <algorithm>
#include <cctype>
#include <cmath>
#include <iterator>
#include <numeric>

#include "util/base/logging.h"
#include "util/math/softmax.h"
#include "util/utf8/unicodetext.h"

namespace libtextclassifier2 {
const std::string& TextClassifier::kOtherCollection =
    *[]() { return new std::string("other"); }();
const std::string& TextClassifier::kPhoneCollection =
    *[]() { return new std::string("phone"); }();
const std::string& TextClassifier::kAddressCollection =
    *[]() { return new std::string("address"); }();
const std::string& TextClassifier::kDateCollection =
    *[]() { return new std::string("date"); }();

namespace {
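// Maps the raw buffer to a flatbuffer Model and runs the flatbuffers verifier
// over it, so that a corrupt or truncated model file is rejected here instead
// of causing undefined behavior later.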
const Model* LoadAndVerifyModel(const void* addr, int size) {
  const Model* model = GetModel(addr);

  flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(addr), size);
  if (model->Verify(verifier)) {
    return model;
  } else {
    return nullptr;
  }
}
}  // namespace

tflite::Interpreter* InterpreterManager::SelectionInterpreter() {
  if (!selection_interpreter_) {
    TC_CHECK(selection_executor_);
    selection_interpreter_ = selection_executor_->CreateInterpreter();
    if (!selection_interpreter_) {
      TC_LOG(ERROR) << "Could not build TFLite interpreter.";
    }
  }
  return selection_interpreter_.get();
}

tflite::Interpreter* InterpreterManager::ClassificationInterpreter() {
  if (!classification_interpreter_) {
    TC_CHECK(classification_executor_);
    classification_interpreter_ = classification_executor_->CreateInterpreter();
    if (!classification_interpreter_) {
      TC_LOG(ERROR) << "Could not build TFLite interpreter.";
    }
  }
  return classification_interpreter_.get();
}

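// Factory methods. Each one loads and verifies the flatbuffer model, wraps it
// in a TextClassifier and returns nullptr if any step fails, so callers only
// need a single null check.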
std::unique_ptr<TextClassifier> TextClassifier::FromUnownedBuffer(
    const char* buffer, int size, const UniLib* unilib) {
  const Model* model = LoadAndVerifyModel(buffer, size);
  if (model == nullptr) {
    return nullptr;
  }

  auto classifier =
      std::unique_ptr<TextClassifier>(new TextClassifier(model, unilib));
  if (!classifier->IsInitialized()) {
    return nullptr;
  }

  return classifier;
}

std::unique_ptr<TextClassifier> TextClassifier::FromScopedMmap(
    std::unique_ptr<ScopedMmap>* mmap, const UniLib* unilib) {
  if (!(*mmap)->handle().ok()) {
    TC_VLOG(1) << "Mmap failed.";
    return nullptr;
  }

  const Model* model = LoadAndVerifyModel((*mmap)->handle().start(),
                                          (*mmap)->handle().num_bytes());
  if (!model) {
    TC_LOG(ERROR) << "Model verification failed.";
    return nullptr;
  }

  auto classifier =
      std::unique_ptr<TextClassifier>(new TextClassifier(mmap, model, unilib));
  if (!classifier->IsInitialized()) {
    return nullptr;
  }

  return classifier;
}

std::unique_ptr<TextClassifier> TextClassifier::FromFileDescriptor(
    int fd, int offset, int size, const UniLib* unilib) {
  std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd, offset, size));
  return FromScopedMmap(&mmap, unilib);
}

std::unique_ptr<TextClassifier> TextClassifier::FromFileDescriptor(
    int fd, const UniLib* unilib) {
  std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(fd));
  return FromScopedMmap(&mmap, unilib);
}

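// Illustrative usage (the model path and UniLib setup below are hypothetical,
// not part of this file):
//   UniLib unilib;
//   std::unique_ptr<TextClassifier> classifier =
//       TextClassifier::FromPath("/path/to/textclassifier.model", &unilib);
//   if (classifier != nullptr) { /* ready to call Annotate/ClassifyText */ }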
std::unique_ptr<TextClassifier> TextClassifier::FromPath(
    const std::string& path, const UniLib* unilib) {
  std::unique_ptr<ScopedMmap> mmap(new ScopedMmap(path));
  return FromScopedMmap(&mmap, unilib);
}

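// Validates the flatbuffer options and sets up the executors and feature
// processors needed by the enabled modes. The checks run in dependency order:
// selection, classification, shared embeddings, then the optional regex,
// datetime and output-option sections. Any failure leaves
// initialized_ == false.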
void TextClassifier::ValidateAndInitialize() {
  initialized_ = false;

  if (model_ == nullptr) {
    TC_LOG(ERROR) << "No model specified.";
    return;
  }

  const bool model_enabled_for_annotation =
      (model_->triggering_options() != nullptr &&
       (model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION));
  const bool model_enabled_for_classification =
      (model_->triggering_options() != nullptr &&
       (model_->triggering_options()->enabled_modes() &
        ModeFlag_CLASSIFICATION));
  const bool model_enabled_for_selection =
      (model_->triggering_options() != nullptr &&
       (model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION));

  // Annotation requires the selection model.
  if (model_enabled_for_annotation || model_enabled_for_selection) {
    if (!model_->selection_options()) {
      TC_LOG(ERROR) << "No selection options.";
      return;
    }
    if (!model_->selection_feature_options()) {
      TC_LOG(ERROR) << "No selection feature options.";
      return;
    }
    if (!model_->selection_feature_options()->bounds_sensitive_features()) {
      TC_LOG(ERROR) << "No selection bounds sensitive feature options.";
      return;
    }
    if (!model_->selection_model()) {
      TC_LOG(ERROR) << "No selection model.";
      return;
    }
    selection_executor_ = ModelExecutor::Instance(model_->selection_model());
    if (!selection_executor_) {
      TC_LOG(ERROR) << "Could not initialize selection executor.";
      return;
    }
    selection_feature_processor_.reset(
        new FeatureProcessor(model_->selection_feature_options(), unilib_));
  }

  // Annotation requires the classification model for conflict resolution and
  // scoring.
  // Selection requires the classification model for conflict resolution.
  if (model_enabled_for_annotation || model_enabled_for_classification ||
      model_enabled_for_selection) {
    if (!model_->classification_options()) {
      TC_LOG(ERROR) << "No classification options.";
      return;
    }

    if (!model_->classification_feature_options()) {
      TC_LOG(ERROR) << "No classification feature options.";
      return;
    }

    if (!model_->classification_feature_options()
             ->bounds_sensitive_features()) {
      TC_LOG(ERROR) << "No classification bounds sensitive feature options.";
      return;
    }
    if (!model_->classification_model()) {
      TC_LOG(ERROR) << "No classification model.";
      return;
    }

    classification_executor_ =
        ModelExecutor::Instance(model_->classification_model());
    if (!classification_executor_) {
      TC_LOG(ERROR) << "Could not initialize classification executor.";
      return;
    }

    classification_feature_processor_.reset(new FeatureProcessor(
        model_->classification_feature_options(), unilib_));
  }

  // The embeddings need to be specified if the model is to be used for
  // annotation, classification or selection.
  if (model_enabled_for_annotation || model_enabled_for_classification ||
      model_enabled_for_selection) {
    if (!model_->embedding_model()) {
      TC_LOG(ERROR) << "No embedding model.";
      return;
    }

    // Check that the embedding sizes of the selection and classification
    // models match, as they share the same embeddings.
    if (model_enabled_for_selection &&
        (model_->selection_feature_options()->embedding_size() !=
             model_->classification_feature_options()->embedding_size() ||
         model_->selection_feature_options()->embedding_quantization_bits() !=
             model_->classification_feature_options()
                 ->embedding_quantization_bits())) {
      TC_LOG(ERROR) << "Mismatching embedding size/quantization.";
      return;
    }

    embedding_executor_ = TFLiteEmbeddingExecutor::Instance(
        model_->embedding_model(),
        model_->classification_feature_options()->embedding_size(),
        model_->classification_feature_options()
            ->embedding_quantization_bits());
    if (!embedding_executor_) {
      TC_LOG(ERROR) << "Could not initialize embedding executor.";
      return;
    }
  }

  std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
  if (model_->regex_model()) {
    if (!InitializeRegexModel(decompressor.get())) {
      TC_LOG(ERROR) << "Could not initialize regex model.";
      return;
    }
  }

  if (model_->datetime_model()) {
    datetime_parser_ = DatetimeParser::Instance(model_->datetime_model(),
                                                *unilib_, decompressor.get());
    if (!datetime_parser_) {
      TC_LOG(ERROR) << "Could not initialize datetime parser.";
      return;
    }
  }

  if (model_->output_options()) {
    if (model_->output_options()->filtered_collections_annotation()) {
      for (const auto collection :
           *model_->output_options()->filtered_collections_annotation()) {
        filtered_collections_annotation_.insert(collection->str());
      }
    }
    if (model_->output_options()->filtered_collections_classification()) {
      for (const auto collection :
           *model_->output_options()->filtered_collections_classification()) {
        filtered_collections_classification_.insert(collection->str());
      }
    }
    if (model_->output_options()->filtered_collections_selection()) {
      for (const auto collection :
           *model_->output_options()->filtered_collections_selection()) {
        filtered_collections_selection_.insert(collection->str());
      }
    }
  }

  initialized_ = true;
}

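// Compiles the (optionally compressed) regex patterns from the model and
// buckets their ids into the per-mode pattern lists (annotation,
// classification, selection), so that each entry point only evaluates the
// patterns enabled for its mode.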
bool TextClassifier::InitializeRegexModel(ZlibDecompressor* decompressor) {
  if (!model_->regex_model()->patterns()) {
    return true;
  }

  // Initialize pattern recognizers.
  int regex_pattern_id = 0;
  for (const auto& regex_pattern : *model_->regex_model()->patterns()) {
    std::unique_ptr<UniLib::RegexPattern> compiled_pattern =
        UncompressMakeRegexPattern(*unilib_, regex_pattern->pattern(),
                                   regex_pattern->compressed_pattern(),
                                   decompressor);
    if (!compiled_pattern) {
      TC_LOG(INFO) << "Failed to load regex pattern";
      return false;
    }

    if (regex_pattern->enabled_modes() & ModeFlag_ANNOTATION) {
      annotation_regex_patterns_.push_back(regex_pattern_id);
    }
    if (regex_pattern->enabled_modes() & ModeFlag_CLASSIFICATION) {
      classification_regex_patterns_.push_back(regex_pattern_id);
    }
    if (regex_pattern->enabled_modes() & ModeFlag_SELECTION) {
      selection_regex_patterns_.push_back(regex_pattern_id);
    }
    regex_patterns_.push_back({regex_pattern->collection_name()->str(),
                               regex_pattern->target_classification_score(),
                               regex_pattern->priority_score(),
                               std::move(compiled_pattern)});
    if (regex_pattern->use_approximate_matching()) {
      regex_approximate_match_pattern_ids_.insert(regex_pattern_id);
    }
    ++regex_pattern_id;
  }

  return true;
}

namespace {

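// Returns the number of digit codepoints that fall within
// 'selection_indices'. Illustrative example (hypothetical input, not from the
// original source):
//   CountDigits("(800) 123-4567", {0, 14}) == 10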
int CountDigits(const std::string& str, CodepointSpan selection_indices) {
  int count = 0;
  int i = 0;
  const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false);
  for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) {
    if (i >= selection_indices.first && i < selection_indices.second &&
        isdigit(*it)) {
      ++count;
    }
  }
  return count;
}

std::string ExtractSelection(const std::string& context,
                             CodepointSpan selection_indices) {
  const UnicodeText context_unicode =
      UTF8ToUnicodeText(context, /*do_copy=*/false);
  auto selection_begin = context_unicode.begin();
  std::advance(selection_begin, selection_indices.first);
  auto selection_end = context_unicode.begin();
  std::advance(selection_end, selection_indices.second);
  return UnicodeText::UTF8Substring(selection_begin, selection_end);
}
}  // namespace

namespace internal {
// Helper function: if the initial 'span' contains only whitespace, snaps the
// selection to a single-codepoint selection on the first non-whitespace
// character to the left of the span; otherwise the span is returned unchanged.
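// Illustrative example (hypothetical input): for context "foo bar" and span
// {3, 4} (the space), the returned span is {2, 3}, i.e. the "o" to the left
// of the space.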
CodepointSpan SnapLeftIfWhitespaceSelection(CodepointSpan span,
                                            const UnicodeText& context_unicode,
                                            const UniLib& unilib) {
  TC_CHECK(ValidNonEmptySpan(span));

  UnicodeText::const_iterator it;

  // Check that the current selection is all whitespace.
  it = context_unicode.begin();
  std::advance(it, span.first);
  for (int i = 0; i < (span.second - span.first); ++i, ++it) {
    if (!unilib.IsWhitespace(*it)) {
      return span;
    }
  }

  CodepointSpan result;

  // Try moving left.
  result = span;
  it = context_unicode.begin();
  std::advance(it, span.first);
  while (it != context_unicode.begin() && unilib.IsWhitespace(*it)) {
    --result.first;
    --it;
  }
  result.second = result.first + 1;
  if (!unilib.IsWhitespace(*it)) {
    return result;
  }

  // If moving left didn't find a non-whitespace character, just return the
  // original span.
  return span;
}
}  // namespace internal

bool TextClassifier::FilteredForAnnotation(const AnnotatedSpan& span) const {
  return !span.classification.empty() &&
         filtered_collections_annotation_.find(
             span.classification[0].collection) !=
             filtered_collections_annotation_.end();
}

bool TextClassifier::FilteredForClassification(
    const ClassificationResult& classification) const {
  return filtered_collections_classification_.find(classification.collection) !=
         filtered_collections_classification_.end();
}

bool TextClassifier::FilteredForSelection(const AnnotatedSpan& span) const {
  return !span.classification.empty() &&
         filtered_collections_selection_.find(
             span.classification[0].collection) !=
             filtered_collections_selection_.end();
}

CodepointSpan TextClassifier::SuggestSelection(
    const std::string& context, CodepointSpan click_indices,
    const SelectionOptions& options) const {
  CodepointSpan original_click_indices = click_indices;
  if (!initialized_) {
    TC_LOG(ERROR) << "Not initialized";
    return original_click_indices;
  }
  if (!(model_->enabled_modes() & ModeFlag_SELECTION)) {
    return original_click_indices;
  }

  const UnicodeText context_unicode = UTF8ToUnicodeText(context,
                                                        /*do_copy=*/false);

  if (!context_unicode.is_valid()) {
    return original_click_indices;
  }

  const int context_codepoint_size = context_unicode.size_codepoints();

  if (click_indices.first < 0 || click_indices.second < 0 ||
      click_indices.first >= context_codepoint_size ||
      click_indices.second > context_codepoint_size ||
      click_indices.first >= click_indices.second) {
    TC_VLOG(1) << "Trying to run SuggestSelection with invalid indices: "
               << click_indices.first << " " << click_indices.second;
    return original_click_indices;
  }

  if (model_->snap_whitespace_selections()) {
    // We want to expand a purely whitespace selection to the multi-token
    // selection it would have been part of. Without adjustment this would be
    // a no-op, because no clicked token is found inside whitespace. Therefore
    // we nudge 'click_indices' to include part of an adjacent token, so that
    // the click-finding logic locates the clicked token correctly. It is
    // enough to check the left side of the current selection: if the
    // whitespace is part of a multi-token selection, the tokens on both its
    // left and right necessarily get selected, so snapping only to the left
    // is sufficient. (The check at the bottom makes sure that if we snap to
    // the left token but the result does not contain the initial whitespace,
    // the original indices are returned.)
    click_indices = internal::SnapLeftIfWhitespaceSelection(
        click_indices, context_unicode, *unilib_);
  }

  std::vector<AnnotatedSpan> candidates;
  InterpreterManager interpreter_manager(selection_executor_.get(),
                                         classification_executor_.get());
  std::vector<Token> tokens;
  if (!ModelSuggestSelection(context_unicode, click_indices,
                             &interpreter_manager, &tokens, &candidates)) {
    TC_LOG(ERROR) << "Model suggest selection failed.";
    return original_click_indices;
  }
  if (!RegexChunk(context_unicode, selection_regex_patterns_, &candidates)) {
    TC_LOG(ERROR) << "Regex suggest selection failed.";
    return original_click_indices;
  }
  if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
                     /*reference_time_ms_utc=*/0, /*reference_timezone=*/"",
                     options.locales, ModeFlag_SELECTION, &candidates)) {
    TC_LOG(ERROR) << "Datetime suggest selection failed.";
    return original_click_indices;
  }

  // Sort candidates according to their position in the input, so that the next
  // code can assume that any connected component of overlapping spans forms a
  // contiguous block.
  std::sort(candidates.begin(), candidates.end(),
            [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
              return a.span.first < b.span.first;
            });

  std::vector<int> candidate_indices;
  if (!ResolveConflicts(candidates, context, tokens, &interpreter_manager,
                        &candidate_indices)) {
    TC_LOG(ERROR) << "Couldn't resolve conflicts.";
    return original_click_indices;
  }

  for (const int i : candidate_indices) {
    if (SpansOverlap(candidates[i].span, click_indices) &&
        SpansOverlap(candidates[i].span, original_click_indices)) {
      // Run model classification if it's not present yet but is requested and
      // a selection collection filter is specified.
      if (candidates[i].classification.empty() &&
          model_->selection_options()->always_classify_suggested_selection() &&
          !filtered_collections_selection_.empty()) {
        if (!ModelClassifyText(
                context, candidates[i].span, &interpreter_manager,
                /*embedding_cache=*/nullptr, &candidates[i].classification)) {
          return original_click_indices;
        }
      }

      // Ignore if span classification is filtered.
      if (FilteredForSelection(candidates[i])) {
        return original_click_indices;
      }

      return candidates[i].span;
    }
  }

  return original_click_indices;
}

namespace {
// Helper function that returns the index of the first candidate that
// transitively does not overlap with the candidate at 'start_index'. If the
// end of 'candidates' is reached, it returns the index one past the end of
// the vector.
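// Illustrative example (hypothetical spans, not from the original source):
// for candidate spans {0,2}, {1,3}, {5,6} and start_index 0, the spans at
// indices 0 and 1 form one overlapping group, so the function returns 2.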
int FirstNonOverlappingSpanIndex(const std::vector<AnnotatedSpan>& candidates,
                                 int start_index) {
  int first_non_overlapping = start_index + 1;
  CodepointSpan conflicting_span = candidates[start_index].span;
  while (
      first_non_overlapping < candidates.size() &&
      SpansOverlap(conflicting_span, candidates[first_non_overlapping].span)) {
    // Grow the span to include the current one.
    conflicting_span.second = std::max(
        conflicting_span.second, candidates[first_non_overlapping].span.second);

    ++first_non_overlapping;
  }
  return first_non_overlapping;
}
}  // namespace

bool TextClassifier::ResolveConflicts(
    const std::vector<AnnotatedSpan>& candidates, const std::string& context,
    const std::vector<Token>& cached_tokens,
    InterpreterManager* interpreter_manager, std::vector<int>* result) const {
  result->clear();
  result->reserve(candidates.size());
  for (int i = 0; i < candidates.size();) {
    int first_non_overlapping =
        FirstNonOverlappingSpanIndex(candidates, /*start_index=*/i);

    const bool conflict_found = first_non_overlapping != (i + 1);
    if (conflict_found) {
      std::vector<int> candidate_indices;
      if (!ResolveConflict(context, cached_tokens, candidates, i,
                           first_non_overlapping, interpreter_manager,
                           &candidate_indices)) {
        return false;
      }
      result->insert(result->end(), candidate_indices.begin(),
                     candidate_indices.end());
    } else {
      result->push_back(i);
    }

    // Skip over the whole conflicting group/go to next candidate.
    i = first_non_overlapping;
  }
  return true;
}

namespace {
inline bool ClassifiedAsOther(
    const std::vector<ClassificationResult>& classification) {
  return !classification.empty() &&
         classification[0].collection == TextClassifier::kOtherCollection;
}

float GetPriorityScore(
    const std::vector<ClassificationResult>& classification) {
  if (!ClassifiedAsOther(classification)) {
    return classification[0].priority_score;
  } else {
    return -1.0;
  }
}
}  // namespace

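// Resolves a single group of conflicting (transitively overlapping) candidate
// spans: each candidate is assigned a priority score (classifying it with the
// model lazily if needed), the candidates are sorted by that score, and then
// they are greedily kept whenever they do not overlap an already chosen one.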
bool TextClassifier::ResolveConflict(
    const std::string& context, const std::vector<Token>& cached_tokens,
    const std::vector<AnnotatedSpan>& candidates, int start_index,
    int end_index, InterpreterManager* interpreter_manager,
    std::vector<int>* chosen_indices) const {
  std::vector<int> conflicting_indices;
  std::unordered_map<int, float> scores;
  for (int i = start_index; i < end_index; ++i) {
    conflicting_indices.push_back(i);
    if (!candidates[i].classification.empty()) {
      scores[i] = GetPriorityScore(candidates[i].classification);
      continue;
    }

    // OPTIMIZATION: So that we don't have to classify all the ML model spans
    // up front, we wait until we get here, when they conflict with something
    // and we need the actual classification scores. So if the candidate
    // conflicts and comes from the model, we need to run a classification to
    // determine its priority:
    std::vector<ClassificationResult> classification;
    if (!ModelClassifyText(context, cached_tokens, candidates[i].span,
                           interpreter_manager,
                           /*embedding_cache=*/nullptr, &classification)) {
      return false;
    }

    if (!classification.empty()) {
      scores[i] = GetPriorityScore(classification);
    }
  }

  std::sort(conflicting_indices.begin(), conflicting_indices.end(),
            [&scores](int i, int j) { return scores[i] > scores[j]; });

  // Keeps the candidates sorted by their position in the text (their left span
  // index) for fast retrieval later.
  std::set<int, std::function<bool(int, int)>> chosen_indices_set(
      [&candidates](int a, int b) {
        return candidates[a].span.first < candidates[b].span.first;
      });

  // Greedily place the candidates if they don't conflict with the already
  // placed ones.
  for (int i = 0; i < conflicting_indices.size(); ++i) {
    const int considered_candidate = conflicting_indices[i];
    if (!DoesCandidateConflict(considered_candidate, candidates,
                               chosen_indices_set)) {
      chosen_indices_set.insert(considered_candidate);
    }
  }

  *chosen_indices =
      std::vector<int>(chosen_indices_set.begin(), chosen_indices_set.end());

  return true;
}

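// Produces selection candidates with the neural model: tokenizes the context,
// locates the clicked token, extracts features for a window around the click,
// and runs the selection TFLite model to produce chunk candidates.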
bool TextClassifier::ModelSuggestSelection(
    const UnicodeText& context_unicode, CodepointSpan click_indices,
    InterpreterManager* interpreter_manager, std::vector<Token>* tokens,
    std::vector<AnnotatedSpan>* result) const {
  if (model_->triggering_options() == nullptr ||
      !(model_->triggering_options()->enabled_modes() & ModeFlag_SELECTION)) {
    return true;
  }

  int click_pos;
  *tokens = selection_feature_processor_->Tokenize(context_unicode);
  selection_feature_processor_->RetokenizeAndFindClick(
      context_unicode, click_indices,
      selection_feature_processor_->GetOptions()->only_use_line_with_click(),
      tokens, &click_pos);
  if (click_pos == kInvalidIndex) {
    TC_VLOG(1) << "Could not calculate the click position.";
    return false;
  }

  const int symmetry_context_size =
      model_->selection_options()->symmetry_context_size();
  const FeatureProcessorOptions_::BoundsSensitiveFeatures*
      bounds_sensitive_features = selection_feature_processor_->GetOptions()
                                      ->bounds_sensitive_features();

  // The symmetry context span is the clicked token with symmetry_context_size
  // tokens on either side.
  const TokenSpan symmetry_context_span = IntersectTokenSpans(
      ExpandTokenSpan(SingleTokenSpan(click_pos),
                      /*num_tokens_left=*/symmetry_context_size,
                      /*num_tokens_right=*/symmetry_context_size),
      {0, tokens->size()});

  // Compute the extraction span based on the model type.
  TokenSpan extraction_span;
  if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
    // The extraction span is the symmetry context span expanded to include
    // max_selection_span tokens on either side, which is how far a selection
    // can stretch from the click, plus a relevant number of tokens outside of
    // the bounds of the selection.
    const int max_selection_span =
        selection_feature_processor_->GetOptions()->max_selection_span();
    extraction_span =
        ExpandTokenSpan(symmetry_context_span,
                        /*num_tokens_left=*/max_selection_span +
                            bounds_sensitive_features->num_tokens_before(),
                        /*num_tokens_right=*/max_selection_span +
                            bounds_sensitive_features->num_tokens_after());
  } else {
    // The extraction span is the symmetry context span expanded to include
    // context_size tokens on either side.
    const int context_size =
        selection_feature_processor_->GetOptions()->context_size();
    extraction_span = ExpandTokenSpan(symmetry_context_span,
                                      /*num_tokens_left=*/context_size,
                                      /*num_tokens_right=*/context_size);
  }
  extraction_span = IntersectTokenSpans(extraction_span, {0, tokens->size()});

  if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
          *tokens, extraction_span)) {
    return true;
  }

  std::unique_ptr<CachedFeatures> cached_features;
  if (!selection_feature_processor_->ExtractFeatures(
          *tokens, extraction_span,
          /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
          embedding_executor_.get(),
          /*embedding_cache=*/nullptr,
          selection_feature_processor_->EmbeddingSize() +
              selection_feature_processor_->DenseFeaturesCount(),
          &cached_features)) {
    TC_LOG(ERROR) << "Could not extract features.";
    return false;
  }

  // Produce selection model candidates.
  std::vector<TokenSpan> chunks;
  if (!ModelChunk(tokens->size(), /*span_of_interest=*/symmetry_context_span,
                  interpreter_manager->SelectionInterpreter(), *cached_features,
                  &chunks)) {
    TC_LOG(ERROR) << "Could not chunk.";
    return false;
  }

  for (const TokenSpan& chunk : chunks) {
    AnnotatedSpan candidate;
    candidate.span = selection_feature_processor_->StripBoundaryCodepoints(
        context_unicode, TokenSpanToCodepointSpan(*tokens, chunk));
    if (model_->selection_options()->strip_unpaired_brackets()) {
      candidate.span =
          StripUnpairedBrackets(context_unicode, candidate.span, *unilib_);
    }

    // Only output non-empty spans.
    if (candidate.span.first != candidate.span.second) {
      result->push_back(candidate);
    }
  }
  return true;
}

bool TextClassifier::ModelClassifyText(
    const std::string& context, CodepointSpan selection_indices,
    InterpreterManager* interpreter_manager,
    FeatureProcessor::EmbeddingCache* embedding_cache,
    std::vector<ClassificationResult>* classification_results) const {
  if (model_->triggering_options() == nullptr ||
      !(model_->triggering_options()->enabled_modes() &
        ModeFlag_CLASSIFICATION)) {
    return true;
  }
  return ModelClassifyText(context, {}, selection_indices, interpreter_manager,
                           embedding_cache, classification_results);
}

namespace internal {
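// Copies the slice of 'cached_tokens' that covers the selection plus
// 'tokens_around_selection_to_copy' extra tokens on each side, so that
// classification of a candidate span does not have to re-tokenize the whole
// context.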
std::vector<Token> CopyCachedTokens(const std::vector<Token>& cached_tokens,
                                    CodepointSpan selection_indices,
                                    TokenSpan tokens_around_selection_to_copy) {
  const auto first_selection_token = std::upper_bound(
      cached_tokens.begin(), cached_tokens.end(), selection_indices.first,
      [](int selection_start, const Token& token) {
        return selection_start < token.end;
      });
  const auto last_selection_token = std::lower_bound(
      cached_tokens.begin(), cached_tokens.end(), selection_indices.second,
      [](const Token& token, int selection_end) {
        return token.start < selection_end;
      });

  const int64 first_token = std::max(
      static_cast<int64>(0),
      static_cast<int64>((first_selection_token - cached_tokens.begin()) -
                         tokens_around_selection_to_copy.first));
  const int64 last_token = std::min(
      static_cast<int64>(cached_tokens.size()),
      static_cast<int64>((last_selection_token - cached_tokens.begin()) +
                         tokens_around_selection_to_copy.second));

  std::vector<Token> tokens;
  tokens.reserve(last_token - first_token);
  for (int i = first_token; i < last_token; ++i) {
    tokens.push_back(cached_tokens[i]);
  }
  return tokens;
}
}  // namespace internal

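// Returns an upper bound on the number of tokens needed to the left and right
// of a selection by ModelClassifyText, so that CopyCachedTokens can copy a
// sufficient margin around the selection.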
TokenSpan TextClassifier::ClassifyTextUpperBoundNeededTokens() const {
  const FeatureProcessorOptions_::BoundsSensitiveFeatures*
      bounds_sensitive_features =
          classification_feature_processor_->GetOptions()
              ->bounds_sensitive_features();
  if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
    // The extraction span is the selection span expanded to include a relevant
    // number of tokens outside of the bounds of the selection.
    return {bounds_sensitive_features->num_tokens_before(),
            bounds_sensitive_features->num_tokens_after()};
  } else {
    // The extraction span is the clicked token with context_size tokens on
    // either side.
    const int context_size =
        selection_feature_processor_->GetOptions()->context_size();
    return {context_size, context_size};
  }
}

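// Classifies the span with the neural model: reuses (or computes) tokens,
// builds an extraction span around the selection, extracts features, runs the
// classification TFLite model, converts the logits to a distribution with a
// softmax, and finally applies the phone/address sanity checks.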
bool TextClassifier::ModelClassifyText(
    const std::string& context, const std::vector<Token>& cached_tokens,
    CodepointSpan selection_indices, InterpreterManager* interpreter_manager,
    FeatureProcessor::EmbeddingCache* embedding_cache,
    std::vector<ClassificationResult>* classification_results) const {
  std::vector<Token> tokens;
  if (cached_tokens.empty()) {
    tokens = classification_feature_processor_->Tokenize(context);
  } else {
    tokens = internal::CopyCachedTokens(cached_tokens, selection_indices,
                                        ClassifyTextUpperBoundNeededTokens());
  }

  int click_pos;
  classification_feature_processor_->RetokenizeAndFindClick(
      context, selection_indices,
      classification_feature_processor_->GetOptions()
          ->only_use_line_with_click(),
      &tokens, &click_pos);
  const TokenSpan selection_token_span =
      CodepointSpanToTokenSpan(tokens, selection_indices);
  const int selection_num_tokens = TokenSpanSize(selection_token_span);
  if (model_->classification_options()->max_num_tokens() > 0 &&
      model_->classification_options()->max_num_tokens() <
          selection_num_tokens) {
    *classification_results = {{kOtherCollection, 1.0}};
    return true;
  }

  const FeatureProcessorOptions_::BoundsSensitiveFeatures*
      bounds_sensitive_features =
          classification_feature_processor_->GetOptions()
              ->bounds_sensitive_features();
  if (selection_token_span.first == kInvalidIndex ||
      selection_token_span.second == kInvalidIndex) {
    TC_LOG(ERROR) << "Could not determine span.";
    return false;
  }

  // Compute the extraction span based on the model type.
  TokenSpan extraction_span;
  if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
    // The extraction span is the selection span expanded to include a relevant
    // number of tokens outside of the bounds of the selection.
    extraction_span = ExpandTokenSpan(
        selection_token_span,
        /*num_tokens_left=*/bounds_sensitive_features->num_tokens_before(),
        /*num_tokens_right=*/bounds_sensitive_features->num_tokens_after());
  } else {
    if (click_pos == kInvalidIndex) {
      TC_LOG(ERROR) << "Couldn't choose a click position.";
      return false;
    }
    // The extraction span is the clicked token with context_size tokens on
    // either side.
    const int context_size =
        classification_feature_processor_->GetOptions()->context_size();
    extraction_span = ExpandTokenSpan(SingleTokenSpan(click_pos),
                                      /*num_tokens_left=*/context_size,
                                      /*num_tokens_right=*/context_size);
  }
  extraction_span = IntersectTokenSpans(extraction_span, {0, tokens.size()});

  if (!classification_feature_processor_->HasEnoughSupportedCodepoints(
          tokens, extraction_span)) {
    *classification_results = {{kOtherCollection, 1.0}};
    return true;
  }

  std::unique_ptr<CachedFeatures> cached_features;
  if (!classification_feature_processor_->ExtractFeatures(
          tokens, extraction_span, selection_indices, embedding_executor_.get(),
          embedding_cache,
          classification_feature_processor_->EmbeddingSize() +
              classification_feature_processor_->DenseFeaturesCount(),
          &cached_features)) {
    TC_LOG(ERROR) << "Could not extract features.";
    return false;
  }

  std::vector<float> features;
  features.reserve(cached_features->OutputFeaturesSize());
  if (bounds_sensitive_features && bounds_sensitive_features->enabled()) {
    cached_features->AppendBoundsSensitiveFeaturesForSpan(selection_token_span,
                                                          &features);
  } else {
    cached_features->AppendClickContextFeaturesForClick(click_pos, &features);
  }

  TensorView<float> logits = classification_executor_->ComputeLogits(
      TensorView<float>(features.data(),
                        {1, static_cast<int>(features.size())}),
      interpreter_manager->ClassificationInterpreter());
  if (!logits.is_valid()) {
    TC_LOG(ERROR) << "Couldn't compute logits.";
    return false;
  }

  if (logits.dims() != 2 || logits.dim(0) != 1 ||
      logits.dim(1) != classification_feature_processor_->NumCollections()) {
    TC_LOG(ERROR) << "Mismatching output";
    return false;
  }

  const std::vector<float> scores =
      ComputeSoftmax(logits.data(), logits.dim(1));

  classification_results->resize(scores.size());
  for (int i = 0; i < scores.size(); i++) {
    (*classification_results)[i] = {
        classification_feature_processor_->LabelToCollection(i), scores[i]};
  }
  std::sort(classification_results->begin(), classification_results->end(),
            [](const ClassificationResult& a, const ClassificationResult& b) {
              return a.score > b.score;
            });

  // Phone class sanity check.
  if (!classification_results->empty() &&
      classification_results->begin()->collection == kPhoneCollection) {
    const int digit_count = CountDigits(context, selection_indices);
    if (digit_count <
            model_->classification_options()->phone_min_num_digits() ||
        digit_count >
            model_->classification_options()->phone_max_num_digits()) {
      *classification_results = {{kOtherCollection, 1.0}};
    }
  }

  // Address class sanity check.
  if (!classification_results->empty() &&
      classification_results->begin()->collection == kAddressCollection) {
    if (selection_num_tokens <
        model_->classification_options()->address_min_num_tokens()) {
      *classification_results = {{kOtherCollection, 1.0}};
    }
  }

  return true;
}

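// Classifies the selection with the regex patterns enabled for
// classification. The pattern is evaluated against just the selected text and
// has to match it in full, either exactly or approximately if the pattern
// opted into approximate matching.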
bool TextClassifier::RegexClassifyText(
    const std::string& context, CodepointSpan selection_indices,
    ClassificationResult* classification_result) const {
  const std::string selection_text =
      ExtractSelection(context, selection_indices);
  const UnicodeText selection_text_unicode(
      UTF8ToUnicodeText(selection_text, /*do_copy=*/false));

  // Check whether any of the regular expressions match.
  for (const int pattern_id : classification_regex_patterns_) {
    const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
    const std::unique_ptr<UniLib::RegexMatcher> matcher =
        regex_pattern.pattern->Matcher(selection_text_unicode);
    int status = UniLib::RegexMatcher::kNoError;
    bool matches;
    if (regex_approximate_match_pattern_ids_.find(pattern_id) !=
        regex_approximate_match_pattern_ids_.end()) {
      matches = matcher->ApproximatelyMatches(&status);
    } else {
      matches = matcher->Matches(&status);
    }
    if (status != UniLib::RegexMatcher::kNoError) {
      return false;
    }
    if (matches) {
      *classification_result = {regex_pattern.collection_name,
                                regex_pattern.target_classification_score,
                                regex_pattern.priority_score};
      return true;
    }
    if (status != UniLib::RegexMatcher::kNoError) {
      TC_LOG(ERROR) << "Couldn't match regex: " << pattern_id;
    }
  }

  return false;
}

bool TextClassifier::DatetimeClassifyText(
    const std::string& context, CodepointSpan selection_indices,
    const ClassificationOptions& options,
    ClassificationResult* classification_result) const {
  if (!datetime_parser_) {
    return false;
  }

  const std::string selection_text =
      ExtractSelection(context, selection_indices);

  std::vector<DatetimeParseResultSpan> datetime_spans;
  if (!datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
                               options.reference_timezone, options.locales,
                               ModeFlag_CLASSIFICATION,
                               /*anchor_start_end=*/true, &datetime_spans)) {
    TC_LOG(ERROR) << "Error during datetime parsing.";
    return false;
  }
  for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
    // Only consider the result valid if the selection and extracted datetime
    // spans exactly match.
    if (std::make_pair(datetime_span.span.first + selection_indices.first,
                       datetime_span.span.second + selection_indices.first) ==
        selection_indices) {
      *classification_result = {kDateCollection,
                                datetime_span.target_classification_score};
      classification_result->datetime_parse_result = datetime_span.data;
      return true;
    }
  }
  return false;
}

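// Classifies the given selection by trying the available models in order of
// precedence: regex patterns first, then the datetime parser, and finally the
// neural classification model. A result whose collection is filtered by the
// model's output options is reported as "other".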
std::vector<ClassificationResult> TextClassifier::ClassifyText(
    const std::string& context, CodepointSpan selection_indices,
    const ClassificationOptions& options) const {
  if (!initialized_) {
    TC_LOG(ERROR) << "Not initialized";
    return {};
  }

  if (!(model_->enabled_modes() & ModeFlag_CLASSIFICATION)) {
    return {};
  }

  if (!UTF8ToUnicodeText(context, /*do_copy=*/false).is_valid()) {
    return {};
  }

  if (std::get<0>(selection_indices) >= std::get<1>(selection_indices)) {
    TC_VLOG(1) << "Trying to run ClassifyText with invalid indices: "
               << std::get<0>(selection_indices) << " "
               << std::get<1>(selection_indices);
    return {};
  }

  // Try the regular expression models.
  ClassificationResult regex_result;
  if (RegexClassifyText(context, selection_indices, &regex_result)) {
    if (!FilteredForClassification(regex_result)) {
      return {regex_result};
    } else {
      return {{kOtherCollection, 1.0}};
    }
  }

  // Try the datetime model.
  ClassificationResult datetime_result;
  if (DatetimeClassifyText(context, selection_indices, options,
                           &datetime_result)) {
    if (!FilteredForClassification(datetime_result)) {
      return {datetime_result};
    } else {
      return {{kOtherCollection, 1.0}};
    }
  }

  // Fall back to the neural model.
  std::vector<ClassificationResult> model_result;

  InterpreterManager interpreter_manager(selection_executor_.get(),
                                         classification_executor_.get());
  if (ModelClassifyText(context, selection_indices, &interpreter_manager,
                        /*embedding_cache=*/nullptr, &model_result) &&
      !model_result.empty()) {
    if (!FilteredForClassification(model_result[0])) {
      return model_result;
    } else {
      return {{kOtherCollection, 1.0}};
    }
  }

  // No classifications.
  return {};
}

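// Runs the selection model over the whole context (split into lines if the
// model only uses the line with the click), chunks each line into candidate
// spans and classifies every candidate, keeping those that clear the
// min_annotate_confidence threshold and are not classified as "other".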
bool TextClassifier::ModelAnnotate(const std::string& context,
                                   InterpreterManager* interpreter_manager,
                                   std::vector<Token>* tokens,
                                   std::vector<AnnotatedSpan>* result) const {
  if (model_->triggering_options() == nullptr ||
      !(model_->triggering_options()->enabled_modes() & ModeFlag_ANNOTATION)) {
    return true;
  }

  const UnicodeText context_unicode = UTF8ToUnicodeText(context,
                                                        /*do_copy=*/false);
  std::vector<UnicodeTextRange> lines;
  if (!selection_feature_processor_->GetOptions()->only_use_line_with_click()) {
    lines.push_back({context_unicode.begin(), context_unicode.end()});
  } else {
    lines = selection_feature_processor_->SplitContext(context_unicode);
  }

  const float min_annotate_confidence =
      (model_->triggering_options() != nullptr
           ? model_->triggering_options()->min_annotate_confidence()
           : 0.f);

  FeatureProcessor::EmbeddingCache embedding_cache;
  for (const UnicodeTextRange& line : lines) {
    const std::string line_str =
        UnicodeText::UTF8Substring(line.first, line.second);

    *tokens = selection_feature_processor_->Tokenize(line_str);
    selection_feature_processor_->RetokenizeAndFindClick(
        line_str, {0, std::distance(line.first, line.second)},
        selection_feature_processor_->GetOptions()->only_use_line_with_click(),
        tokens,
        /*click_pos=*/nullptr);
    const TokenSpan full_line_span = {0, tokens->size()};

    // TODO(zilka): Add support for greater granularity of this check.
    if (!selection_feature_processor_->HasEnoughSupportedCodepoints(
            *tokens, full_line_span)) {
      continue;
    }

    std::unique_ptr<CachedFeatures> cached_features;
    if (!selection_feature_processor_->ExtractFeatures(
            *tokens, full_line_span,
            /*selection_span_for_feature=*/{kInvalidIndex, kInvalidIndex},
            embedding_executor_.get(),
            /*embedding_cache=*/nullptr,
            selection_feature_processor_->EmbeddingSize() +
                selection_feature_processor_->DenseFeaturesCount(),
            &cached_features)) {
      TC_LOG(ERROR) << "Could not extract features.";
      return false;
    }

    std::vector<TokenSpan> local_chunks;
    if (!ModelChunk(tokens->size(), /*span_of_interest=*/full_line_span,
                    interpreter_manager->SelectionInterpreter(),
                    *cached_features, &local_chunks)) {
      TC_LOG(ERROR) << "Could not chunk.";
      return false;
    }

    const int offset = std::distance(context_unicode.begin(), line.first);
    for (const TokenSpan& chunk : local_chunks) {
      const CodepointSpan codepoint_span =
          selection_feature_processor_->StripBoundaryCodepoints(
              line_str, TokenSpanToCodepointSpan(*tokens, chunk));

      // Skip empty spans.
      if (codepoint_span.first != codepoint_span.second) {
        std::vector<ClassificationResult> classification;
        if (!ModelClassifyText(line_str, *tokens, codepoint_span,
                               interpreter_manager, &embedding_cache,
                               &classification)) {
          TC_LOG(ERROR) << "Could not classify text: "
                        << (codepoint_span.first + offset) << " "
                        << (codepoint_span.second + offset);
          return false;
        }

        // Do not include the span if it's classified as "other".
        if (!classification.empty() && !ClassifiedAsOther(classification) &&
            classification[0].score >= min_annotate_confidence) {
          AnnotatedSpan result_span;
          result_span.span = {codepoint_span.first + offset,
                              codepoint_span.second + offset};
          result_span.classification = std::move(classification);
          result->push_back(std::move(result_span));
        }
      }
    }
  }
  return true;
}

const FeatureProcessor* TextClassifier::SelectionFeatureProcessorForTests()
    const {
  return selection_feature_processor_.get();
}

const FeatureProcessor* TextClassifier::ClassificationFeatureProcessorForTests()
    const {
  return classification_feature_processor_.get();
}

const DatetimeParser* TextClassifier::DatetimeParserForTests() const {
  return datetime_parser_.get();
}

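// Annotates the whole context: candidates are collected from the selection
// model, the annotation regex patterns and the datetime parser, then sorted,
// de-conflicted, and filtered against the model's output options before being
// returned.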
   1204 std::vector<AnnotatedSpan> TextClassifier::Annotate(
   1205     const std::string& context, const AnnotationOptions& options) const {
   1206   std::vector<AnnotatedSpan> candidates;
   1207 
   1208   if (!(model_->enabled_modes() & ModeFlag_ANNOTATION)) {
   1209     return {};
   1210   }
   1211 
   1212   if (!UTF8ToUnicodeText(context, /*do_copy=*/false).is_valid()) {
   1213     return {};
   1214   }
   1215 
   1216   InterpreterManager interpreter_manager(selection_executor_.get(),
   1217                                          classification_executor_.get());
   1218   // Annotate with the selection model.
   1219   std::vector<Token> tokens;
   1220   if (!ModelAnnotate(context, &interpreter_manager, &tokens, &candidates)) {
   1221     TC_LOG(ERROR) << "Couldn't run ModelAnnotate.";
   1222     return {};
   1223   }
   1224 
   1225   // Annotate with the regular expression models.
   1226   if (!RegexChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
   1227                   annotation_regex_patterns_, &candidates)) {
   1228     TC_LOG(ERROR) << "Couldn't run RegexChunk.";
   1229     return {};
   1230   }
   1231 
   1232   // Annotate with the datetime model.
   1233   if (!DatetimeChunk(UTF8ToUnicodeText(context, /*do_copy=*/false),
   1234                      options.reference_time_ms_utc, options.reference_timezone,
   1235                      options.locales, ModeFlag_ANNOTATION, &candidates)) {
   1236     TC_LOG(ERROR) << "Couldn't run RegexChunk.";
   1237     return {};
   1238   }
   1239 
   1240   // Sort candidates according to their position in the input, so that the next
   1241   // code can assume that any connected component of overlapping spans forms a
   1242   // contiguous block.
   1243   std::sort(candidates.begin(), candidates.end(),
   1244             [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
   1245               return a.span.first < b.span.first;
   1246             });
   1247 
   1248   std::vector<int> candidate_indices;
   1249   if (!ResolveConflicts(candidates, context, tokens, &interpreter_manager,
   1250                         &candidate_indices)) {
   1251     TC_LOG(ERROR) << "Couldn't resolve conflicts.";
   1252     return {};
   1253   }
   1254 
   1255   std::vector<AnnotatedSpan> result;
   1256   result.reserve(candidate_indices.size());
   1257   for (const int i : candidate_indices) {
   1258     if (!candidates[i].classification.empty() &&
   1259         !ClassifiedAsOther(candidates[i].classification) &&
   1260         !FilteredForAnnotation(candidates[i])) {
   1261       result.push_back(std::move(candidates[i]));
   1262     }
   1263   }
   1264 
   1265   return result;
   1266 }
   1267 
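         // Matches each of the given regex patterns against |context_unicode| and,
         // for every match, appends a span annotated with the pattern's collection
         // name, classification score and priority score.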
   1268 bool TextClassifier::RegexChunk(const UnicodeText& context_unicode,
   1269                                 const std::vector<int>& rules,
   1270                                 std::vector<AnnotatedSpan>* result) const {
   1271   for (int pattern_id : rules) {
   1272     const CompiledRegexPattern& regex_pattern = regex_patterns_[pattern_id];
   1273     const auto matcher = regex_pattern.pattern->Matcher(context_unicode);
   1274     if (!matcher) {
   1275       TC_LOG(ERROR) << "Could not get regex matcher for pattern: "
   1276                     << pattern_id;
   1277       return false;
   1278     }
   1279 
   1280     int status = UniLib::RegexMatcher::kNoError;
   1281     while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
   1282       result->emplace_back();
    1283       // Selection/annotation regular expressions must contain a capturing
    1284       // group that specifies the selected span.
   1285       result->back().span = {matcher->Start(1, &status),
   1286                              matcher->End(1, &status)};
   1287       result->back().classification = {
   1288           {regex_pattern.collection_name,
   1289            regex_pattern.target_classification_score,
   1290            regex_pattern.priority_score}};
   1291     }
   1292   }
   1293   return true;
   1294 }
   1295 
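         // Produces a set of non-overlapping token spans (chunks) around the span of
         // interest by scoring candidate spans with the selection model and greedily
         // keeping the highest-scoring ones that do not conflict.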
   1296 bool TextClassifier::ModelChunk(int num_tokens,
   1297                                 const TokenSpan& span_of_interest,
   1298                                 tflite::Interpreter* selection_interpreter,
   1299                                 const CachedFeatures& cached_features,
   1300                                 std::vector<TokenSpan>* chunks) const {
   1301   const int max_selection_span =
   1302       selection_feature_processor_->GetOptions()->max_selection_span();
   1303   // The inference span is the span of interest expanded to include
   1304   // max_selection_span tokens on either side, which is how far a selection can
   1305   // stretch from the click.
   1306   const TokenSpan inference_span = IntersectTokenSpans(
   1307       ExpandTokenSpan(span_of_interest,
   1308                       /*num_tokens_left=*/max_selection_span,
   1309                       /*num_tokens_right=*/max_selection_span),
   1310       {0, num_tokens});
   1311 
   1312   std::vector<ScoredChunk> scored_chunks;
   1313   if (selection_feature_processor_->GetOptions()->bounds_sensitive_features() &&
   1314       selection_feature_processor_->GetOptions()
   1315           ->bounds_sensitive_features()
   1316           ->enabled()) {
   1317     if (!ModelBoundsSensitiveScoreChunks(
   1318             num_tokens, span_of_interest, inference_span, cached_features,
   1319             selection_interpreter, &scored_chunks)) {
   1320       return false;
   1321     }
   1322   } else {
   1323     if (!ModelClickContextScoreChunks(num_tokens, span_of_interest,
   1324                                       cached_features, selection_interpreter,
   1325                                       &scored_chunks)) {
   1326       return false;
   1327     }
   1328   }
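           // Note: sorting the reversed range in ascending order leaves
           // scored_chunks in descending score order when traversed front to back.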
   1329   std::sort(scored_chunks.rbegin(), scored_chunks.rend(),
   1330             [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
   1331               return lhs.score < rhs.score;
   1332             });
   1333 
   1334   // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
   1335   // them greedily as long as they do not overlap with any previously picked
   1336   // chunks.
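           // E.g., if the scored spans are {2,5}:0.9, {4,6}:0.8 and {0,2}:0.7, the
           // loop below keeps {2,5}, drops {4,6} because token 4 is already used,
           // and then keeps {0,2}.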
   1337   std::vector<bool> token_used(TokenSpanSize(inference_span));
   1338   chunks->clear();
   1339   for (const ScoredChunk& scored_chunk : scored_chunks) {
   1340     bool feasible = true;
   1341     for (int i = scored_chunk.token_span.first;
   1342          i < scored_chunk.token_span.second; ++i) {
   1343       if (token_used[i - inference_span.first]) {
   1344         feasible = false;
   1345         break;
   1346       }
   1347     }
   1348 
   1349     if (!feasible) {
   1350       continue;
   1351     }
   1352 
   1353     for (int i = scored_chunk.token_span.first;
   1354          i < scored_chunk.token_span.second; ++i) {
   1355       token_used[i - inference_span.first] = true;
   1356     }
   1357 
   1358     chunks->push_back(scored_chunk.token_span);
   1359   }
   1360 
   1361   std::sort(chunks->begin(), chunks->end());
   1362 
   1363   return true;
   1364 }
   1365 
   1366 namespace {
    1367 // Updates the value at the given key in the map to the maximum of the current
    1368 // and the given value, or simply inserts the value if the key is not yet there.
   1369 template <typename Map>
   1370 void UpdateMax(Map* map, typename Map::key_type key,
   1371                typename Map::mapped_type value) {
   1372   const auto it = map->find(key);
   1373   if (it != map->end()) {
   1374     it->second = std::max(it->second, value);
   1375   } else {
   1376     (*map)[key] = value;
   1377   }
   1378 }
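         // E.g., calling UpdateMax(&m, span, 0.4), then UpdateMax(&m, span, 0.7) and
         // finally UpdateMax(&m, span, 0.5) leaves m[span] == 0.7.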
   1379 }  // namespace
   1380 
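         // Scores candidate spans with the click-context selection model: clicks
         // inside the span of interest are processed in batches, each label's score
         // is mapped to the corresponding span around the click, and the maximum
         // score seen for every distinct span is kept.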
   1381 bool TextClassifier::ModelClickContextScoreChunks(
   1382     int num_tokens, const TokenSpan& span_of_interest,
   1383     const CachedFeatures& cached_features,
   1384     tflite::Interpreter* selection_interpreter,
   1385     std::vector<ScoredChunk>* scored_chunks) const {
   1386   const int max_batch_size = model_->selection_options()->batch_size();
   1387 
   1388   std::vector<float> all_features;
   1389   std::map<TokenSpan, float> chunk_scores;
   1390   for (int batch_start = span_of_interest.first;
   1391        batch_start < span_of_interest.second; batch_start += max_batch_size) {
   1392     const int batch_end =
   1393         std::min(batch_start + max_batch_size, span_of_interest.second);
   1394 
   1395     // Prepare features for the whole batch.
   1396     all_features.clear();
   1397     all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
   1398     for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
   1399       cached_features.AppendClickContextFeaturesForClick(click_pos,
   1400                                                          &all_features);
   1401     }
   1402 
   1403     // Run batched inference.
   1404     const int batch_size = batch_end - batch_start;
   1405     const int features_size = cached_features.OutputFeaturesSize();
   1406     TensorView<float> logits = selection_executor_->ComputeLogits(
   1407         TensorView<float>(all_features.data(), {batch_size, features_size}),
   1408         selection_interpreter);
   1409     if (!logits.is_valid()) {
   1410       TC_LOG(ERROR) << "Couldn't compute logits.";
   1411       return false;
   1412     }
   1413     if (logits.dims() != 2 || logits.dim(0) != batch_size ||
   1414         logits.dim(1) !=
   1415             selection_feature_processor_->GetSelectionLabelCount()) {
   1416       TC_LOG(ERROR) << "Mismatching output.";
   1417       return false;
   1418     }
   1419 
   1420     // Save results.
   1421     for (int click_pos = batch_start; click_pos < batch_end; ++click_pos) {
   1422       const std::vector<float> scores = ComputeSoftmax(
   1423           logits.data() + logits.dim(1) * (click_pos - batch_start),
   1424           logits.dim(1));
   1425       for (int j = 0;
   1426            j < selection_feature_processor_->GetSelectionLabelCount(); ++j) {
   1427         TokenSpan relative_token_span;
   1428         if (!selection_feature_processor_->LabelToTokenSpan(
   1429                 j, &relative_token_span)) {
   1430           TC_LOG(ERROR) << "Couldn't map the label to a token span.";
   1431           return false;
   1432         }
   1433         const TokenSpan candidate_span = ExpandTokenSpan(
   1434             SingleTokenSpan(click_pos), relative_token_span.first,
   1435             relative_token_span.second);
   1436         if (candidate_span.first >= 0 && candidate_span.second <= num_tokens) {
   1437           UpdateMax(&chunk_scores, candidate_span, scores[j]);
   1438         }
   1439       }
   1440     }
   1441   }
   1442 
   1443   scored_chunks->clear();
   1444   scored_chunks->reserve(chunk_scores.size());
   1445   for (const auto& entry : chunk_scores) {
   1446     scored_chunks->push_back(ScoredChunk{entry.first, entry.second});
   1447   }
   1448 
   1449   return true;
   1450 }
   1451 
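         // Scores candidate spans with the bounds-sensitive selection model: every
         // span that fits in the inference span, overlaps the span of interest and
         // is at most max_chunk_length tokens long is scored in batches; single-token
         // spans may optionally be given a fixed score of zero instead.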
   1452 bool TextClassifier::ModelBoundsSensitiveScoreChunks(
   1453     int num_tokens, const TokenSpan& span_of_interest,
   1454     const TokenSpan& inference_span, const CachedFeatures& cached_features,
   1455     tflite::Interpreter* selection_interpreter,
   1456     std::vector<ScoredChunk>* scored_chunks) const {
   1457   const int max_selection_span =
   1458       selection_feature_processor_->GetOptions()->max_selection_span();
   1459   const int max_chunk_length = selection_feature_processor_->GetOptions()
   1460                                        ->selection_reduced_output_space()
   1461                                    ? max_selection_span + 1
   1462                                    : 2 * max_selection_span + 1;
   1463   const bool score_single_token_spans_as_zero =
   1464       selection_feature_processor_->GetOptions()
   1465           ->bounds_sensitive_features()
   1466           ->score_single_token_spans_as_zero();
   1467 
   1468   scored_chunks->clear();
   1469   if (score_single_token_spans_as_zero) {
   1470     scored_chunks->reserve(TokenSpanSize(span_of_interest));
   1471   }
   1472 
    1473   // Collect into one batch all chunk candidates that:
   1474   //   - Are contained in the inference span
   1475   //   - Have a non-empty intersection with the span of interest
   1476   //   - Are at least one token long
   1477   //   - Are not longer than the maximum chunk length
   1478   std::vector<TokenSpan> candidate_spans;
   1479   for (int start = inference_span.first; start < span_of_interest.second;
   1480        ++start) {
   1481     const int leftmost_end_index = std::max(start, span_of_interest.first) + 1;
   1482     for (int end = leftmost_end_index;
   1483          end <= inference_span.second && end - start <= max_chunk_length;
   1484          ++end) {
   1485       const TokenSpan candidate_span = {start, end};
   1486       if (score_single_token_spans_as_zero &&
   1487           TokenSpanSize(candidate_span) == 1) {
    1488         // Do not include the single-token span in the batch; add a zero
    1489         // score for it directly to the output.
   1490         scored_chunks->push_back(ScoredChunk{candidate_span, 0.0f});
   1491       } else {
   1492         candidate_spans.push_back(candidate_span);
   1493       }
   1494     }
   1495   }
   1496 
   1497   const int max_batch_size = model_->selection_options()->batch_size();
   1498 
   1499   std::vector<float> all_features;
   1500   scored_chunks->reserve(scored_chunks->size() + candidate_spans.size());
   1501   for (int batch_start = 0; batch_start < candidate_spans.size();
   1502        batch_start += max_batch_size) {
   1503     const int batch_end = std::min(batch_start + max_batch_size,
   1504                                    static_cast<int>(candidate_spans.size()));
   1505 
   1506     // Prepare features for the whole batch.
   1507     all_features.clear();
   1508     all_features.reserve(max_batch_size * cached_features.OutputFeaturesSize());
   1509     for (int i = batch_start; i < batch_end; ++i) {
   1510       cached_features.AppendBoundsSensitiveFeaturesForSpan(candidate_spans[i],
   1511                                                            &all_features);
   1512     }
   1513 
   1514     // Run batched inference.
   1515     const int batch_size = batch_end - batch_start;
   1516     const int features_size = cached_features.OutputFeaturesSize();
   1517     TensorView<float> logits = selection_executor_->ComputeLogits(
   1518         TensorView<float>(all_features.data(), {batch_size, features_size}),
   1519         selection_interpreter);
   1520     if (!logits.is_valid()) {
   1521       TC_LOG(ERROR) << "Couldn't compute logits.";
   1522       return false;
   1523     }
   1524     if (logits.dims() != 2 || logits.dim(0) != batch_size ||
   1525         logits.dim(1) != 1) {
   1526       TC_LOG(ERROR) << "Mismatching output.";
   1527       return false;
   1528     }
   1529 
   1530     // Save results.
   1531     for (int i = batch_start; i < batch_end; ++i) {
   1532       scored_chunks->push_back(
   1533           ScoredChunk{candidate_spans[i], logits.data()[i - batch_start]});
   1534     }
   1535   }
   1536 
   1537   return true;
   1538 }
   1539 
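         // Runs the datetime parser (if one is available) over the context and
         // appends one annotated span per parsed datetime, tagged with the "date"
         // collection and carrying the parse result.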
   1540 bool TextClassifier::DatetimeChunk(const UnicodeText& context_unicode,
   1541                                    int64 reference_time_ms_utc,
   1542                                    const std::string& reference_timezone,
   1543                                    const std::string& locales, ModeFlag mode,
   1544                                    std::vector<AnnotatedSpan>* result) const {
   1545   if (!datetime_parser_) {
   1546     return true;
   1547   }
   1548 
   1549   std::vector<DatetimeParseResultSpan> datetime_spans;
   1550   if (!datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
   1551                                reference_timezone, locales, mode,
   1552                                /*anchor_start_end=*/false, &datetime_spans)) {
   1553     return false;
   1554   }
   1555   for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
   1556     AnnotatedSpan annotated_span;
   1557     annotated_span.span = datetime_span.span;
   1558     annotated_span.classification = {{kDateCollection,
   1559                                       datetime_span.target_classification_score,
   1560                                       datetime_span.priority_score}};
   1561     annotated_span.classification[0].datetime_parse_result = datetime_span.data;
   1562 
   1563     result->push_back(std::move(annotated_span));
   1564   }
   1565   return true;
   1566 }
   1567 
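         // Returns a read-only view of the model stored in |buffer|, or nullptr if
         // the buffer is null or the model fails verification.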
   1568 const Model* ViewModel(const void* buffer, int size) {
   1569   if (!buffer) {
   1570     return nullptr;
   1571   }
   1572 
   1573   return LoadAndVerifyModel(buffer, size);
   1574 }
   1575 
   1576 }  // namespace libtextclassifier2
   1577