Home | History | Annotate | Download | only in annotator
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "annotator/feature-processor.h"
     18 
#include <algorithm>
#include <iterator>
#include <set>
#include <vector>
     22 
     23 #include "utils/base/logging.h"
     24 #include "utils/strings/utf8.h"
     25 #include "utils/utf8/unicodetext.h"
     26 
     27 namespace libtextclassifier3 {
     28 
     29 namespace internal {
     30 
     31 Tokenizer BuildTokenizer(const FeatureProcessorOptions* options,
     32                          const UniLib* unilib) {
     33   std::vector<const TokenizationCodepointRange*> codepoint_config;
     34   if (options->tokenization_codepoint_config() != nullptr) {
     35     codepoint_config.insert(codepoint_config.end(),
     36                             options->tokenization_codepoint_config()->begin(),
     37                             options->tokenization_codepoint_config()->end());
     38   }
     39   std::vector<const CodepointRange*> internal_codepoint_config;
     40   if (options->internal_tokenizer_codepoint_ranges() != nullptr) {
     41     internal_codepoint_config.insert(
     42         internal_codepoint_config.end(),
     43         options->internal_tokenizer_codepoint_ranges()->begin(),
     44         options->internal_tokenizer_codepoint_ranges()->end());
     45   }
     46   const bool tokenize_on_script_change =
     47       options->tokenization_codepoint_config() != nullptr &&
     48       options->tokenize_on_script_change();
     49   return Tokenizer(options->tokenization_type(), unilib, codepoint_config,
     50                    internal_codepoint_config, tokenize_on_script_change,
     51                    options->icu_preserve_whitespace_tokens());
     52 }
     53 
     54 TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
     55     const FeatureProcessorOptions* const options) {
     56   TokenFeatureExtractorOptions extractor_options;
     57 
     58   extractor_options.num_buckets = options->num_buckets();
     59   if (options->chargram_orders() != nullptr) {
     60     for (int order : *options->chargram_orders()) {
     61       extractor_options.chargram_orders.push_back(order);
     62     }
     63   }
     64   extractor_options.max_word_length = options->max_word_length();
     65   extractor_options.extract_case_feature = options->extract_case_feature();
     66   extractor_options.unicode_aware_features = options->unicode_aware_features();
     67   extractor_options.extract_selection_mask_feature =
     68       options->extract_selection_mask_feature();
     69   if (options->regexp_feature() != nullptr) {
     70     for (const auto& regexp_feauture : *options->regexp_feature()) {
     71       extractor_options.regexp_features.push_back(regexp_feauture->str());
     72     }
     73   }
     74   extractor_options.remap_digits = options->remap_digits();
     75   extractor_options.lowercase_tokens = options->lowercase_tokens();
     76 
     77   if (options->allowed_chargrams() != nullptr) {
     78     for (const auto& chargram : *options->allowed_chargrams()) {
     79       extractor_options.allowed_chargrams.insert(chargram->str());
     80     }
     81   }
     82   return extractor_options;
     83 }
     84 
     85 void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
     86                                       std::vector<Token>* tokens) {
     87   for (auto it = tokens->begin(); it != tokens->end(); ++it) {
     88     const UnicodeText token_word =
     89         UTF8ToUnicodeText(it->value, /*do_copy=*/false);
     90 
     91     auto last_start = token_word.begin();
     92     int last_start_index = it->start;
     93     std::vector<UnicodeText::const_iterator> split_points;
     94 
     95     // Selection start split point.
     96     if (selection.first > it->start && selection.first < it->end) {
     97       std::advance(last_start, selection.first - last_start_index);
     98       split_points.push_back(last_start);
     99       last_start_index = selection.first;
    100     }
    101 
    102     // Selection end split point.
    103     if (selection.second > it->start && selection.second < it->end) {
    104       std::advance(last_start, selection.second - last_start_index);
    105       split_points.push_back(last_start);
    106     }
    107 
    108     if (!split_points.empty()) {
    109       // Add a final split for the rest of the token unless it's been all
    110       // consumed already.
    111       if (split_points.back() != token_word.end()) {
    112         split_points.push_back(token_word.end());
    113       }
    114 
    115       std::vector<Token> replacement_tokens;
    116       last_start = token_word.begin();
    117       int current_pos = it->start;
    118       for (const auto& split_point : split_points) {
    119         Token new_token(token_word.UTF8Substring(last_start, split_point),
    120                         current_pos,
    121                         current_pos + std::distance(last_start, split_point));
    122 
    123         last_start = split_point;
    124         current_pos = new_token.end;
    125 
    126         replacement_tokens.push_back(new_token);
    127       }
    128 
    129       it = tokens->erase(it);
    130       it = tokens->insert(it, replacement_tokens.begin(),
    131                           replacement_tokens.end());
    132       std::advance(it, replacement_tokens.size() - 1);
    133     }
    134   }
    135 }
    136 
    137 }  // namespace internal
    138 
    139 void FeatureProcessor::StripTokensFromOtherLines(
    140     const std::string& context, CodepointSpan span,
    141     std::vector<Token>* tokens) const {
    142   const UnicodeText context_unicode = UTF8ToUnicodeText(context,
    143                                                         /*do_copy=*/false);
    144   StripTokensFromOtherLines(context_unicode, span, tokens);
    145 }
    146 
    147 void FeatureProcessor::StripTokensFromOtherLines(
    148     const UnicodeText& context_unicode, CodepointSpan span,
    149     std::vector<Token>* tokens) const {
    150   std::vector<UnicodeTextRange> lines = SplitContext(context_unicode);
    151 
    152   auto span_start = context_unicode.begin();
    153   if (span.first > 0) {
    154     std::advance(span_start, span.first);
    155   }
    156   auto span_end = context_unicode.begin();
    157   if (span.second > 0) {
    158     std::advance(span_end, span.second);
    159   }
    160   for (const UnicodeTextRange& line : lines) {
    161     // Find the line that completely contains the span.
    162     if (line.first <= span_start && line.second >= span_end) {
    163       const CodepointIndex last_line_begin_index =
    164           std::distance(context_unicode.begin(), line.first);
    165       const CodepointIndex last_line_end_index =
    166           last_line_begin_index + std::distance(line.first, line.second);
    167 
    168       for (auto token = tokens->begin(); token != tokens->end();) {
    169         if (token->start >= last_line_begin_index &&
    170             token->end <= last_line_end_index) {
    171           ++token;
    172         } else {
    173           token = tokens->erase(token);
    174         }
    175       }
    176     }
    177   }
    178 }
    179 
    180 std::string FeatureProcessor::GetDefaultCollection() const {
    181   if (options_->default_collection() < 0 ||
    182       options_->collections() == nullptr ||
    183       options_->default_collection() >= options_->collections()->size()) {
    184     TC3_LOG(ERROR)
    185         << "Invalid or missing default collection. Returning empty string.";
    186     return "";
    187   }
    188   return (*options_->collections())[options_->default_collection()]->str();
    189 }
    190 
    191 std::vector<Token> FeatureProcessor::Tokenize(const std::string& text) const {
    192   return tokenizer_.Tokenize(text);
    193 }
    194 
    195 std::vector<Token> FeatureProcessor::Tokenize(
    196     const UnicodeText& text_unicode) const {
    197   return tokenizer_.Tokenize(text_unicode);
    198 }
    199 
    200 bool FeatureProcessor::LabelToSpan(
    201     const int label, const VectorSpan<Token>& tokens,
    202     std::pair<CodepointIndex, CodepointIndex>* span) const {
    203   if (tokens.size() != GetNumContextTokens()) {
    204     return false;
    205   }
    206 
    207   TokenSpan token_span;
    208   if (!LabelToTokenSpan(label, &token_span)) {
    209     return false;
    210   }
    211 
    212   const int result_begin_token_index = token_span.first;
    213   const Token& result_begin_token =
    214       tokens[options_->context_size() - result_begin_token_index];
    215   const int result_begin_codepoint = result_begin_token.start;
    216   const int result_end_token_index = token_span.second;
    217   const Token& result_end_token =
    218       tokens[options_->context_size() + result_end_token_index];
    219   const int result_end_codepoint = result_end_token.end;
    220 
    221   if (result_begin_codepoint == kInvalidIndex ||
    222       result_end_codepoint == kInvalidIndex) {
    223     *span = CodepointSpan({kInvalidIndex, kInvalidIndex});
    224   } else {
    225     const UnicodeText token_begin_unicode =
    226         UTF8ToUnicodeText(result_begin_token.value, /*do_copy=*/false);
    227     UnicodeText::const_iterator token_begin = token_begin_unicode.begin();
    228     const UnicodeText token_end_unicode =
    229         UTF8ToUnicodeText(result_end_token.value, /*do_copy=*/false);
    230     UnicodeText::const_iterator token_end = token_end_unicode.end();
    231 
    232     const int begin_ignored = CountIgnoredSpanBoundaryCodepoints(
    233         token_begin, token_begin_unicode.end(),
    234         /*count_from_beginning=*/true);
    235     const int end_ignored =
    236         CountIgnoredSpanBoundaryCodepoints(token_end_unicode.begin(), token_end,
    237                                            /*count_from_beginning=*/false);
    238     // In case everything would be stripped, set the span to the original
    239     // beginning and zero length.
    240     if (begin_ignored == (result_end_codepoint - result_begin_codepoint)) {
    241       *span = {result_begin_codepoint, result_begin_codepoint};
    242     } else {
    243       *span = CodepointSpan({result_begin_codepoint + begin_ignored,
    244                              result_end_codepoint - end_ignored});
    245     }
    246   }
    247   return true;
    248 }
    249 
    250 bool FeatureProcessor::LabelToTokenSpan(const int label,
    251                                         TokenSpan* token_span) const {
    252   if (label >= 0 && label < label_to_selection_.size()) {
    253     *token_span = label_to_selection_[label];
    254     return true;
    255   } else {
    256     return false;
    257   }
    258 }
    259 
    260 bool FeatureProcessor::SpanToLabel(
    261     const std::pair<CodepointIndex, CodepointIndex>& span,
    262     const std::vector<Token>& tokens, int* label) const {
    263   if (tokens.size() != GetNumContextTokens()) {
    264     return false;
    265   }
    266 
    267   const int click_position =
    268       options_->context_size();  // Click is always in the middle.
    269   const int padding = options_->context_size() - options_->max_selection_span();
    270 
    271   int span_left = 0;
    272   for (int i = click_position - 1; i >= padding; i--) {
    273     if (tokens[i].start != kInvalidIndex && tokens[i].end > span.first) {
    274       ++span_left;
    275     } else {
    276       break;
    277     }
    278   }
    279 
    280   int span_right = 0;
    281   for (int i = click_position + 1; i < tokens.size() - padding; ++i) {
    282     if (tokens[i].end != kInvalidIndex && tokens[i].start < span.second) {
    283       ++span_right;
    284     } else {
    285       break;
    286     }
    287   }
    288 
    289   // Check that the spanned tokens cover the whole span.
    290   bool tokens_match_span;
    291   const CodepointIndex tokens_start = tokens[click_position - span_left].start;
    292   const CodepointIndex tokens_end = tokens[click_position + span_right].end;
    293   if (options_->snap_label_span_boundaries_to_containing_tokens()) {
    294     tokens_match_span = tokens_start <= span.first && tokens_end >= span.second;
    295   } else {
    296     const UnicodeText token_left_unicode = UTF8ToUnicodeText(
    297         tokens[click_position - span_left].value, /*do_copy=*/false);
    298     const UnicodeText token_right_unicode = UTF8ToUnicodeText(
    299         tokens[click_position + span_right].value, /*do_copy=*/false);
    300 
    301     UnicodeText::const_iterator span_begin = token_left_unicode.begin();
    302     UnicodeText::const_iterator span_end = token_right_unicode.end();
    303 
    304     const int num_punctuation_start = CountIgnoredSpanBoundaryCodepoints(
    305         span_begin, token_left_unicode.end(), /*count_from_beginning=*/true);
    306     const int num_punctuation_end = CountIgnoredSpanBoundaryCodepoints(
    307         token_right_unicode.begin(), span_end,
    308         /*count_from_beginning=*/false);
    309 
    310     tokens_match_span = tokens_start <= span.first &&
    311                         tokens_start + num_punctuation_start >= span.first &&
    312                         tokens_end >= span.second &&
    313                         tokens_end - num_punctuation_end <= span.second;
    314   }
    315 
    316   if (tokens_match_span) {
    317     *label = TokenSpanToLabel({span_left, span_right});
    318   } else {
    319     *label = kInvalidLabel;
    320   }
    321 
    322   return true;
    323 }
    324 
    325 int FeatureProcessor::TokenSpanToLabel(const TokenSpan& span) const {
    326   auto it = selection_to_label_.find(span);
    327   if (it != selection_to_label_.end()) {
    328     return it->second;
    329   } else {
    330     return kInvalidLabel;
    331   }
    332 }
    333 
    334 TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
    335                                    CodepointSpan codepoint_span,
    336                                    bool snap_boundaries_to_containing_tokens) {
    337   const int codepoint_start = std::get<0>(codepoint_span);
    338   const int codepoint_end = std::get<1>(codepoint_span);
    339 
    340   TokenIndex start_token = kInvalidIndex;
    341   TokenIndex end_token = kInvalidIndex;
    342   for (int i = 0; i < selectable_tokens.size(); ++i) {
    343     bool is_token_in_span;
    344     if (snap_boundaries_to_containing_tokens) {
    345       is_token_in_span = codepoint_start < selectable_tokens[i].end &&
    346                          codepoint_end > selectable_tokens[i].start;
    347     } else {
    348       is_token_in_span = codepoint_start <= selectable_tokens[i].start &&
    349                          codepoint_end >= selectable_tokens[i].end;
    350     }
    351     if (is_token_in_span && !selectable_tokens[i].is_padding) {
    352       if (start_token == kInvalidIndex) {
    353         start_token = i;
    354       }
    355       end_token = i + 1;
    356     }
    357   }
    358   return {start_token, end_token};
    359 }
    360 
    361 CodepointSpan TokenSpanToCodepointSpan(
    362     const std::vector<Token>& selectable_tokens, TokenSpan token_span) {
    363   return {selectable_tokens[token_span.first].start,
    364           selectable_tokens[token_span.second - 1].end};
    365 }
    366 
    367 namespace {
    368 
    369 // Finds a single token that completely contains the given span.
    370 int FindTokenThatContainsSpan(const std::vector<Token>& selectable_tokens,
    371                               CodepointSpan codepoint_span) {
    372   const int codepoint_start = std::get<0>(codepoint_span);
    373   const int codepoint_end = std::get<1>(codepoint_span);
    374 
    375   for (int i = 0; i < selectable_tokens.size(); ++i) {
    376     if (codepoint_start >= selectable_tokens[i].start &&
    377         codepoint_end <= selectable_tokens[i].end) {
    378       return i;
    379     }
    380   }
    381   return kInvalidIndex;
    382 }
    383 
    384 }  // namespace
    385 
    386 namespace internal {
    387 
    388 int CenterTokenFromClick(CodepointSpan span,
    389                          const std::vector<Token>& selectable_tokens) {
    390   int range_begin;
    391   int range_end;
    392   std::tie(range_begin, range_end) =
    393       CodepointSpanToTokenSpan(selectable_tokens, span);
    394 
    395   // If no exact match was found, try finding a token that completely contains
    396   // the click span. This is useful e.g. when Android builds the selection
    397   // using ICU tokenization, and ends up with only a portion of our space-
    398   // separated token. E.g. for "(857)" Android would select "857".
    399   if (range_begin == kInvalidIndex || range_end == kInvalidIndex) {
    400     int token_index = FindTokenThatContainsSpan(selectable_tokens, span);
    401     if (token_index != kInvalidIndex) {
    402       range_begin = token_index;
    403       range_end = token_index + 1;
    404     }
    405   }
    406 
    407   // We only allow clicks that are exactly 1 selectable token.
    408   if (range_end - range_begin == 1) {
    409     return range_begin;
    410   } else {
    411     return kInvalidIndex;
    412   }
    413 }
    414 
    415 int CenterTokenFromMiddleOfSelection(
    416     CodepointSpan span, const std::vector<Token>& selectable_tokens) {
    417   int range_begin;
    418   int range_end;
    419   std::tie(range_begin, range_end) =
    420       CodepointSpanToTokenSpan(selectable_tokens, span);
    421 
    422   // Center the clicked token in the selection range.
    423   if (range_begin != kInvalidIndex && range_end != kInvalidIndex) {
    424     return (range_begin + range_end - 1) / 2;
    425   } else {
    426     return kInvalidIndex;
    427   }
    428 }
    429 
    430 }  // namespace internal
    431 
    432 int FeatureProcessor::FindCenterToken(CodepointSpan span,
    433                                       const std::vector<Token>& tokens) const {
    434   if (options_->center_token_selection_method() ==
    435       FeatureProcessorOptions_::
    436           CenterTokenSelectionMethod_CENTER_TOKEN_FROM_CLICK) {
    437     return internal::CenterTokenFromClick(span, tokens);
    438   } else if (options_->center_token_selection_method() ==
    439              FeatureProcessorOptions_::
    440                  CenterTokenSelectionMethod_CENTER_TOKEN_MIDDLE_OF_SELECTION) {
    441     return internal::CenterTokenFromMiddleOfSelection(span, tokens);
    442   } else if (options_->center_token_selection_method() ==
    443              FeatureProcessorOptions_::
    444                  CenterTokenSelectionMethod_DEFAULT_CENTER_TOKEN_METHOD) {
    445     // TODO(zilka): Remove once we have new models on the device.
    446     // It uses the fact that sharing model use
    447     // split_tokens_on_selection_boundaries and selection not. So depending on
    448     // this we select the right way of finding the click location.
    449     if (!options_->split_tokens_on_selection_boundaries()) {
    450       // SmartSelection model.
    451       return internal::CenterTokenFromClick(span, tokens);
    452     } else {
    453       // SmartSharing model.
    454       return internal::CenterTokenFromMiddleOfSelection(span, tokens);
    455     }
    456   } else {
    457     TC3_LOG(ERROR) << "Invalid center token selection method.";
    458     return kInvalidIndex;
    459   }
    460 }
    461 
    462 bool FeatureProcessor::SelectionLabelSpans(
    463     const VectorSpan<Token> tokens,
    464     std::vector<CodepointSpan>* selection_label_spans) const {
    465   for (int i = 0; i < label_to_selection_.size(); ++i) {
    466     CodepointSpan span;
    467     if (!LabelToSpan(i, tokens, &span)) {
    468       TC3_LOG(ERROR) << "Could not convert label to span: " << i;
    469       return false;
    470     }
    471     selection_label_spans->push_back(span);
    472   }
    473   return true;
    474 }
    475 
    476 void FeatureProcessor::PrepareIgnoredSpanBoundaryCodepoints() {
    477   if (options_->ignored_span_boundary_codepoints() != nullptr) {
    478     for (const int codepoint : *options_->ignored_span_boundary_codepoints()) {
    479       ignored_span_boundary_codepoints_.insert(codepoint);
    480     }
    481   }
    482 }
    483 
    484 int FeatureProcessor::CountIgnoredSpanBoundaryCodepoints(
    485     const UnicodeText::const_iterator& span_start,
    486     const UnicodeText::const_iterator& span_end,
    487     bool count_from_beginning) const {
    488   if (span_start == span_end) {
    489     return 0;
    490   }
    491 
    492   UnicodeText::const_iterator it;
    493   UnicodeText::const_iterator it_last;
    494   if (count_from_beginning) {
    495     it = span_start;
    496     it_last = span_end;
    497     // We can assume that the string is non-zero length because of the check
    498     // above, thus the decrement is always valid here.
    499     --it_last;
    500   } else {
    501     it = span_end;
    502     it_last = span_start;
    503     // We can assume that the string is non-zero length because of the check
    504     // above, thus the decrement is always valid here.
    505     --it;
    506   }
    507 
    508   // Move until we encounter a non-ignored character.
    509   int num_ignored = 0;
    510   while (ignored_span_boundary_codepoints_.find(*it) !=
    511          ignored_span_boundary_codepoints_.end()) {
    512     ++num_ignored;
    513 
    514     if (it == it_last) {
    515       break;
    516     }
    517 
    518     if (count_from_beginning) {
    519       ++it;
    520     } else {
    521       --it;
    522     }
    523   }
    524 
    525   return num_ignored;
    526 }
    527 
    528 namespace {
    529 
    530 void FindSubstrings(const UnicodeText& t, const std::set<char32>& codepoints,
    531                     std::vector<UnicodeTextRange>* ranges) {
    532   UnicodeText::const_iterator start = t.begin();
    533   UnicodeText::const_iterator curr = start;
    534   UnicodeText::const_iterator end = t.end();
    535   for (; curr != end; ++curr) {
    536     if (codepoints.find(*curr) != codepoints.end()) {
    537       if (start != curr) {
    538         ranges->push_back(std::make_pair(start, curr));
    539       }
    540       start = curr;
    541       ++start;
    542     }
    543   }
    544   if (start != end) {
    545     ranges->push_back(std::make_pair(start, end));
    546   }
    547 }
    548 
    549 }  // namespace
    550 
    551 std::vector<UnicodeTextRange> FeatureProcessor::SplitContext(
    552     const UnicodeText& context_unicode) const {
    553   std::vector<UnicodeTextRange> lines;
    554   const std::set<char32> codepoints{{'\n', '|'}};
    555   FindSubstrings(context_unicode, codepoints, &lines);
    556   return lines;
    557 }
    558 
    559 CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
    560     const std::string& context, CodepointSpan span) const {
    561   const UnicodeText context_unicode =
    562       UTF8ToUnicodeText(context, /*do_copy=*/false);
    563   return StripBoundaryCodepoints(context_unicode, span);
    564 }
    565 
    566 CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
    567     const UnicodeText& context_unicode, CodepointSpan span) const {
    568   if (context_unicode.empty() || !ValidNonEmptySpan(span)) {
    569     return span;
    570   }
    571 
    572   UnicodeText::const_iterator span_begin = context_unicode.begin();
    573   std::advance(span_begin, span.first);
    574   UnicodeText::const_iterator span_end = context_unicode.begin();
    575   std::advance(span_end, span.second);
    576 
    577   return StripBoundaryCodepoints(span_begin, span_end, span);
    578 }
    579 
    580 CodepointSpan FeatureProcessor::StripBoundaryCodepoints(
    581     const UnicodeText::const_iterator& span_begin,
    582     const UnicodeText::const_iterator& span_end, CodepointSpan span) const {
    583   if (!ValidNonEmptySpan(span) || span_begin == span_end) {
    584     return span;
    585   }
    586 
    587   const int start_offset = CountIgnoredSpanBoundaryCodepoints(
    588       span_begin, span_end, /*count_from_beginning=*/true);
    589   const int end_offset = CountIgnoredSpanBoundaryCodepoints(
    590       span_begin, span_end, /*count_from_beginning=*/false);
    591 
    592   if (span.first + start_offset < span.second - end_offset) {
    593     return {span.first + start_offset, span.second - end_offset};
    594   } else {
    595     return {span.first, span.first};
    596   }
    597 }
    598 
    599 float FeatureProcessor::SupportedCodepointsRatio(
    600     const TokenSpan& token_span, const std::vector<Token>& tokens) const {
    601   int num_supported = 0;
    602   int num_total = 0;
    603   for (int i = token_span.first; i < token_span.second; ++i) {
    604     const UnicodeText value =
    605         UTF8ToUnicodeText(tokens[i].value, /*do_copy=*/false);
    606     for (auto codepoint : value) {
    607       if (IsCodepointInRanges(codepoint, supported_codepoint_ranges_)) {
    608         ++num_supported;
    609       }
    610       ++num_total;
    611     }
    612   }
    613   return static_cast<float>(num_supported) / static_cast<float>(num_total);
    614 }
    615 
    616 const std::string& FeatureProcessor::StripBoundaryCodepoints(
    617     const std::string& value, std::string* buffer) const {
    618   const UnicodeText value_unicode = UTF8ToUnicodeText(value, /*do_copy=*/false);
    619   const CodepointSpan initial_span{0, value_unicode.size_codepoints()};
    620   const CodepointSpan stripped_span =
    621       StripBoundaryCodepoints(value_unicode, initial_span);
    622 
    623   if (initial_span != stripped_span) {
    624     const UnicodeText stripped_token_value =
    625         UnicodeText::Substring(value_unicode, stripped_span.first,
    626                                stripped_span.second, /*do_copy=*/false);
    627     *buffer = stripped_token_value.ToUTF8String();
    628     return *buffer;
    629   }
    630   return value;
    631 }
    632 
    633 int FeatureProcessor::CollectionToLabel(const std::string& collection) const {
    634   const auto it = collection_to_label_.find(collection);
    635   if (it == collection_to_label_.end()) {
    636     return options_->default_collection();
    637   } else {
    638     return it->second;
    639   }
    640 }
    641 
    642 std::string FeatureProcessor::LabelToCollection(int label) const {
    643   if (label >= 0 && label < collection_to_label_.size()) {
    644     return (*options_->collections())[label]->str();
    645   } else {
    646     return GetDefaultCollection();
    647   }
    648 }
    649 
    650 void FeatureProcessor::MakeLabelMaps() {
    651   if (options_->collections() != nullptr) {
    652     for (int i = 0; i < options_->collections()->size(); ++i) {
    653       collection_to_label_[(*options_->collections())[i]->str()] = i;
    654     }
    655   }
    656 
    657   int selection_label_id = 0;
    658   for (int l = 0; l < (options_->max_selection_span() + 1); ++l) {
    659     for (int r = 0; r < (options_->max_selection_span() + 1); ++r) {
    660       if (!options_->selection_reduced_output_space() ||
    661           r + l <= options_->max_selection_span()) {
    662         TokenSpan token_span{l, r};
    663         selection_to_label_[token_span] = selection_label_id;
    664         label_to_selection_.push_back(token_span);
    665         ++selection_label_id;
    666       }
    667     }
    668   }
    669 }
    670 
    671 void FeatureProcessor::RetokenizeAndFindClick(const std::string& context,
    672                                               CodepointSpan input_span,
    673                                               bool only_use_line_with_click,
    674                                               std::vector<Token>* tokens,
    675                                               int* click_pos) const {
    676   const UnicodeText context_unicode =
    677       UTF8ToUnicodeText(context, /*do_copy=*/false);
    678   RetokenizeAndFindClick(context_unicode, input_span, only_use_line_with_click,
    679                          tokens, click_pos);
    680 }
    681 
    682 void FeatureProcessor::RetokenizeAndFindClick(
    683     const UnicodeText& context_unicode, CodepointSpan input_span,
    684     bool only_use_line_with_click, std::vector<Token>* tokens,
    685     int* click_pos) const {
    686   TC3_CHECK(tokens != nullptr);
    687 
    688   if (options_->split_tokens_on_selection_boundaries()) {
    689     internal::SplitTokensOnSelectionBoundaries(input_span, tokens);
    690   }
    691 
    692   if (only_use_line_with_click) {
    693     StripTokensFromOtherLines(context_unicode, input_span, tokens);
    694   }
    695 
    696   int local_click_pos;
    697   if (click_pos == nullptr) {
    698     click_pos = &local_click_pos;
    699   }
    700   *click_pos = FindCenterToken(input_span, *tokens);
    701   if (*click_pos == kInvalidIndex) {
    702     // If the default click method failed, let's try to do sub-token matching
    703     // before we fail.
    704     *click_pos = internal::CenterTokenFromClick(input_span, *tokens);
    705   }
    706 }
    707 
    708 namespace internal {
    709 
    710 void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
    711                       std::vector<Token>* tokens, int* click_pos) {
    712   int right_context_needed = relative_click_span.second + context_size;
    713   if (*click_pos + right_context_needed + 1 >= tokens->size()) {
    714     // Pad max the context size.
    715     const int num_pad_tokens = std::min(
    716         context_size, static_cast<int>(*click_pos + right_context_needed + 1 -
    717                                        tokens->size()));
    718     std::vector<Token> pad_tokens(num_pad_tokens);
    719     tokens->insert(tokens->end(), pad_tokens.begin(), pad_tokens.end());
    720   } else if (*click_pos + right_context_needed + 1 < tokens->size() - 1) {
    721     // Strip unused tokens.
    722     auto it = tokens->begin();
    723     std::advance(it, *click_pos + right_context_needed + 1);
    724     tokens->erase(it, tokens->end());
    725   }
    726 
    727   int left_context_needed = relative_click_span.first + context_size;
    728   if (*click_pos < left_context_needed) {
    729     // Pad max the context size.
    730     const int num_pad_tokens =
    731         std::min(context_size, left_context_needed - *click_pos);
    732     std::vector<Token> pad_tokens(num_pad_tokens);
    733     tokens->insert(tokens->begin(), pad_tokens.begin(), pad_tokens.end());
    734     *click_pos += num_pad_tokens;
    735   } else if (*click_pos > left_context_needed) {
    736     // Strip unused tokens.
    737     auto it = tokens->begin();
    738     std::advance(it, *click_pos - left_context_needed);
    739     *click_pos -= it - tokens->begin();
    740     tokens->erase(tokens->begin(), it);
    741   }
    742 }
    743 
    744 }  // namespace internal
    745 
    746 bool FeatureProcessor::HasEnoughSupportedCodepoints(
    747     const std::vector<Token>& tokens, TokenSpan token_span) const {
    748   if (options_->min_supported_codepoint_ratio() > 0) {
    749     const float supported_codepoint_ratio =
    750         SupportedCodepointsRatio(token_span, tokens);
    751     if (supported_codepoint_ratio < options_->min_supported_codepoint_ratio()) {
    752       TC3_VLOG(1) << "Not enough supported codepoints in the context: "
    753                   << supported_codepoint_ratio;
    754       return false;
    755     }
    756   }
    757   return true;
    758 }
    759 
    760 bool FeatureProcessor::ExtractFeatures(
    761     const std::vector<Token>& tokens, TokenSpan token_span,
    762     CodepointSpan selection_span_for_feature,
    763     const EmbeddingExecutor* embedding_executor,
    764     EmbeddingCache* embedding_cache, int feature_vector_size,
    765     std::unique_ptr<CachedFeatures>* cached_features) const {
    766   std::unique_ptr<std::vector<float>> features(new std::vector<float>());
    767   features->reserve(feature_vector_size * TokenSpanSize(token_span));
    768   for (int i = token_span.first; i < token_span.second; ++i) {
    769     if (!AppendTokenFeaturesWithCache(tokens[i], selection_span_for_feature,
    770                                       embedding_executor, embedding_cache,
    771                                       features.get())) {
    772       TC3_LOG(ERROR) << "Could not get token features.";
    773       return false;
    774     }
    775   }
    776 
    777   std::unique_ptr<std::vector<float>> padding_features(
    778       new std::vector<float>());
    779   padding_features->reserve(feature_vector_size);
    780   if (!AppendTokenFeaturesWithCache(Token(), selection_span_for_feature,
    781                                     embedding_executor, embedding_cache,
    782                                     padding_features.get())) {
    783     TC3_LOG(ERROR) << "Count not get padding token features.";
    784     return false;
    785   }
    786 
    787   *cached_features = CachedFeatures::Create(token_span, std::move(features),
    788                                             std::move(padding_features),
    789                                             options_, feature_vector_size);
    790   if (!*cached_features) {
    791     TC3_LOG(ERROR) << "Cound not create cached features.";
    792     return false;
    793   }
    794 
    795   return true;
    796 }
    797 
    798 bool FeatureProcessor::AppendTokenFeaturesWithCache(
    799     const Token& token, CodepointSpan selection_span_for_feature,
    800     const EmbeddingExecutor* embedding_executor,
    801     EmbeddingCache* embedding_cache,
    802     std::vector<float>* output_features) const {
    803   // Look for the embedded features for the token in the cache, if there is one.
    804   if (embedding_cache) {
    805     const auto it = embedding_cache->find({token.start, token.end});
    806     if (it != embedding_cache->end()) {
    807       // The embedded features were found in the cache, extract only the dense
    808       // features.
    809       std::vector<float> dense_features;
    810       if (!feature_extractor_.Extract(
    811               token, token.IsContainedInSpan(selection_span_for_feature),
    812               /*sparse_features=*/nullptr, &dense_features)) {
    813         TC3_LOG(ERROR) << "Could not extract token's dense features.";
    814         return false;
    815       }
    816 
    817       // Append both embedded and dense features to the output and return.
    818       output_features->insert(output_features->end(), it->second.begin(),
    819                               it->second.end());
    820       output_features->insert(output_features->end(), dense_features.begin(),
    821                               dense_features.end());
    822       return true;
    823     }
    824   }
    825 
    826   // Extract the sparse and dense features.
    827   std::vector<int> sparse_features;
    828   std::vector<float> dense_features;
    829   if (!feature_extractor_.Extract(
    830           token, token.IsContainedInSpan(selection_span_for_feature),
    831           &sparse_features, &dense_features)) {
    832     TC3_LOG(ERROR) << "Could not extract token's features.";
    833     return false;
    834   }
    835 
    836   // Embed the sparse features, appending them directly to the output.
    837   const int embedding_size = GetOptions()->embedding_size();
    838   output_features->resize(output_features->size() + embedding_size);
    839   float* output_features_end =
    840       output_features->data() + output_features->size();
    841   if (!embedding_executor->AddEmbedding(
    842           TensorView<int>(sparse_features.data(),
    843                           {static_cast<int>(sparse_features.size())}),
    844           /*dest=*/output_features_end - embedding_size,
    845           /*dest_size=*/embedding_size)) {
    846     TC3_LOG(ERROR) << "Cound not embed token's sparse features.";
    847     return false;
    848   }
    849 
    850   // If there is a cache, the embedded features for the token were not in it,
    851   // so insert them.
    852   if (embedding_cache) {
    853     (*embedding_cache)[{token.start, token.end}] = std::vector<float>(
    854         output_features_end - embedding_size, output_features_end);
    855   }
    856 
    857   // Append the dense features to the output.
    858   output_features->insert(output_features->end(), dense_features.begin(),
    859                           dense_features.end());
    860   return true;
    861 }
    862 
    863 }  // namespace libtextclassifier3
    864