Home | History | Annotate | Download | only in smartselect
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "smartselect/feature-processor.h"
     18 
#include <algorithm>
#include <iterator>
#include <set>
#include <vector>
     22 
     23 #include "smartselect/text-classification-model.pb.h"
     24 #include "util/base/logging.h"
     25 #include "util/strings/utf8.h"
     26 #include "util/utf8/unicodetext.h"
     27 #include "unicode/brkiter.h"
     28 #include "unicode/errorcode.h"
     29 #include "unicode/uchar.h"
     30 
     31 namespace libtextclassifier {
     32 
     33 namespace internal {
     34 
     35 TokenFeatureExtractorOptions BuildTokenFeatureExtractorOptions(
     36     const FeatureProcessorOptions& options) {
     37   TokenFeatureExtractorOptions extractor_options;
     38 
     39   extractor_options.num_buckets = options.num_buckets();
     40   for (int order : options.chargram_orders()) {
     41     extractor_options.chargram_orders.push_back(order);
     42   }
     43   extractor_options.max_word_length = options.max_word_length();
     44   extractor_options.extract_case_feature = options.extract_case_feature();
     45   extractor_options.unicode_aware_features = options.unicode_aware_features();
     46   extractor_options.extract_selection_mask_feature =
     47       options.extract_selection_mask_feature();
     48   for (int i = 0; i < options.regexp_feature_size(); ++i) {
     49     extractor_options.regexp_features.push_back(options.regexp_feature(i));
     50   }
     51   extractor_options.remap_digits = options.remap_digits();
     52   extractor_options.lowercase_tokens = options.lowercase_tokens();
     53 
     54   return extractor_options;
     55 }
     56 
     57 FeatureProcessorOptions ParseSerializedOptions(
     58     const std::string& serialized_options) {
     59   FeatureProcessorOptions options;
     60   options.ParseFromString(serialized_options);
     61   return options;
     62 }
     63 
// Splits every token that straddles a selection boundary, so that
// `selection.first` and `selection.second` each end up on a token boundary.
//
// A token [start, end) is cut at selection.first and/or selection.second when
// either lies strictly inside it. The resulting pieces replace the original
// token in place, in order. Their codepoint offsets are derived from the
// original token's `start` plus the number of codepoints in each piece.
// Tokens that contain no boundary are left untouched.
//
// NOTE(review): this assumes a token's `value` has exactly end - start
// codepoints. If selection.first == selection.second lies strictly inside a
// token, both split points coincide and an empty token [p, p) is emitted;
// confirm callers expect that.
void SplitTokensOnSelectionBoundaries(CodepointSpan selection,
                                      std::vector<Token>* tokens) {
  for (auto it = tokens->begin(); it != tokens->end(); ++it) {
    const UnicodeText token_word =
        UTF8ToUnicodeText(it->value, /*do_copy=*/false);

    // `last_start` walks through the token's codepoints. `last_start_index`
    // is the codepoint offset (in the whole text) it currently points at.
    auto last_start = token_word.begin();
    int last_start_index = it->start;
    std::vector<UnicodeText::const_iterator> split_points;

    // Selection start split point.
    if (selection.first > it->start && selection.first < it->end) {
      std::advance(last_start, selection.first - last_start_index);
      split_points.push_back(last_start);
      last_start_index = selection.first;
    }

    // Selection end split point. Advanced relative to the previous split
    // point (if any), because `last_start` is reused.
    if (selection.second > it->start && selection.second < it->end) {
      std::advance(last_start, selection.second - last_start_index);
      split_points.push_back(last_start);
    }

    if (!split_points.empty()) {
      // Add a final split for the rest of the token unless it's been all
      // consumed already.
      if (split_points.back() != token_word.end()) {
        split_points.push_back(token_word.end());
      }

      // Build one token per [previous split, split) piece.
      std::vector<Token> replacement_tokens;
      last_start = token_word.begin();
      int current_pos = it->start;
      for (const auto& split_point : split_points) {
        Token new_token(token_word.UTF8Substring(last_start, split_point),
                        current_pos,
                        current_pos + std::distance(last_start, split_point));

        last_start = split_point;
        current_pos = new_token.end;

        replacement_tokens.push_back(new_token);
      }

      // Swap the original token for its pieces. erase() returns an iterator
      // to the following element, and insert() returns one to the first
      // inserted piece. Advancing by size() - 1 leaves `it` on the last
      // piece, so the loop's ++it resumes after the replacement.
      it = tokens->erase(it);
      it = tokens->insert(it, replacement_tokens.begin(),
                          replacement_tokens.end());
      std::advance(it, replacement_tokens.size() - 1);
    }
  }
}
    115 
    116 void FindSubstrings(const UnicodeText& t, const std::set<char32>& codepoints,
    117                     std::vector<UnicodeTextRange>* ranges) {
    118   UnicodeText::const_iterator start = t.begin();
    119   UnicodeText::const_iterator curr = start;
    120   UnicodeText::const_iterator end = t.end();
    121   for (; curr != end; ++curr) {
    122     if (codepoints.find(*curr) != codepoints.end()) {
    123       if (start != curr) {
    124         ranges->push_back(std::make_pair(start, curr));
    125       }
    126       start = curr;
    127       ++start;
    128     }
    129   }
    130   if (start != end) {
    131     ranges->push_back(std::make_pair(start, end));
    132   }
    133 }
    134 
    135 void StripTokensFromOtherLines(const std::string& context, CodepointSpan span,
    136                                std::vector<Token>* tokens) {
    137   const UnicodeText context_unicode = UTF8ToUnicodeText(context,
    138                                                         /*do_copy=*/false);
    139   std::vector<UnicodeTextRange> lines;
    140   std::set<char32> codepoints;
    141   codepoints.insert('\n');
    142   codepoints.insert('|');
    143   internal::FindSubstrings(context_unicode, codepoints, &lines);
    144 
    145   auto span_start = context_unicode.begin();
    146   if (span.first > 0) {
    147     std::advance(span_start, span.first);
    148   }
    149   auto span_end = context_unicode.begin();
    150   if (span.second > 0) {
    151     std::advance(span_end, span.second);
    152   }
    153   for (const UnicodeTextRange& line : lines) {
    154     // Find the line that completely contains the span.
    155     if (line.first <= span_start && line.second >= span_end) {
    156       const CodepointIndex last_line_begin_index =
    157           std::distance(context_unicode.begin(), line.first);
    158       const CodepointIndex last_line_end_index =
    159           last_line_begin_index + std::distance(line.first, line.second);
    160 
    161       for (auto token = tokens->begin(); token != tokens->end();) {
    162         if (token->start >= last_line_begin_index &&
    163             token->end <= last_line_end_index) {
    164           ++token;
    165         } else {
    166           token = tokens->erase(token);
    167         }
    168       }
    169     }
    170   }
    171 }
    172 
    173 }  // namespace internal
    174 
    175 std::string FeatureProcessor::GetDefaultCollection() const {
    176   if (options_.default_collection() >= options_.collections_size()) {
    177     TC_LOG(ERROR) << "No collections specified. Returning empty string.";
    178     return "";
    179   }
    180   return options_.collections(options_.default_collection());
    181 }
    182 
    183 std::vector<Token> FeatureProcessor::Tokenize(
    184     const std::string& utf8_text) const {
    185   if (options_.tokenization_type() ==
    186       libtextclassifier::FeatureProcessorOptions::INTERNAL_TOKENIZER) {
    187     return tokenizer_.Tokenize(utf8_text);
    188   } else if (options_.tokenization_type() ==
    189                  libtextclassifier::FeatureProcessorOptions::ICU ||
    190              options_.tokenization_type() ==
    191                  libtextclassifier::FeatureProcessorOptions::MIXED) {
    192     std::vector<Token> result;
    193     if (!ICUTokenize(utf8_text, &result)) {
    194       return {};
    195     }
    196     if (options_.tokenization_type() ==
    197         libtextclassifier::FeatureProcessorOptions::MIXED) {
    198       InternalRetokenize(utf8_text, &result);
    199     }
    200     return result;
    201   } else {
    202     TC_LOG(ERROR) << "Unknown tokenization type specified. Using "
    203                      "internal.";
    204     return tokenizer_.Tokenize(utf8_text);
    205   }
    206 }
    207 
    208 bool FeatureProcessor::LabelToSpan(
    209     const int label, const VectorSpan<Token>& tokens,
    210     std::pair<CodepointIndex, CodepointIndex>* span) const {
    211   if (tokens.size() != GetNumContextTokens()) {
    212     return false;
    213   }
    214 
    215   TokenSpan token_span;
    216   if (!LabelToTokenSpan(label, &token_span)) {
    217     return false;
    218   }
    219 
    220   const int result_begin_token = token_span.first;
    221   const int result_begin_codepoint =
    222       tokens[options_.context_size() - result_begin_token].start;
    223   const int result_end_token = token_span.second;
    224   const int result_end_codepoint =
    225       tokens[options_.context_size() + result_end_token].end;
    226 
    227   if (result_begin_codepoint == kInvalidIndex ||
    228       result_end_codepoint == kInvalidIndex) {
    229     *span = CodepointSpan({kInvalidIndex, kInvalidIndex});
    230   } else {
    231     *span = CodepointSpan({result_begin_codepoint, result_end_codepoint});
    232   }
    233   return true;
    234 }
    235 
    236 bool FeatureProcessor::LabelToTokenSpan(const int label,
    237                                         TokenSpan* token_span) const {
    238   if (label >= 0 && label < label_to_selection_.size()) {
    239     *token_span = label_to_selection_[label];
    240     return true;
    241   } else {
    242     return false;
    243   }
    244 }
    245 
    246 bool FeatureProcessor::SpanToLabel(
    247     const std::pair<CodepointIndex, CodepointIndex>& span,
    248     const std::vector<Token>& tokens, int* label) const {
    249   if (tokens.size() != GetNumContextTokens()) {
    250     return false;
    251   }
    252 
    253   const int click_position =
    254       options_.context_size();  // Click is always in the middle.
    255   const int padding = options_.context_size() - options_.max_selection_span();
    256 
    257   int span_left = 0;
    258   for (int i = click_position - 1; i >= padding; i--) {
    259     if (tokens[i].start != kInvalidIndex && tokens[i].end > span.first) {
    260       ++span_left;
    261     } else {
    262       break;
    263     }
    264   }
    265 
    266   int span_right = 0;
    267   for (int i = click_position + 1; i < tokens.size() - padding; ++i) {
    268     if (tokens[i].end != kInvalidIndex && tokens[i].start < span.second) {
    269       ++span_right;
    270     } else {
    271       break;
    272     }
    273   }
    274 
    275   // Check that the spanned tokens cover the whole span.
    276   bool tokens_match_span;
    277   if (options_.snap_label_span_boundaries_to_containing_tokens()) {
    278     tokens_match_span =
    279         tokens[click_position - span_left].start <= span.first &&
    280         tokens[click_position + span_right].end >= span.second;
    281   } else {
    282     tokens_match_span =
    283         tokens[click_position - span_left].start == span.first &&
    284         tokens[click_position + span_right].end == span.second;
    285   }
    286 
    287   if (tokens_match_span) {
    288     *label = TokenSpanToLabel({span_left, span_right});
    289   } else {
    290     *label = kInvalidLabel;
    291   }
    292 
    293   return true;
    294 }
    295 
    296 int FeatureProcessor::TokenSpanToLabel(const TokenSpan& span) const {
    297   auto it = selection_to_label_.find(span);
    298   if (it != selection_to_label_.end()) {
    299     return it->second;
    300   } else {
    301     return kInvalidLabel;
    302   }
    303 }
    304 
    305 TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
    306                                    CodepointSpan codepoint_span) {
    307   const int codepoint_start = std::get<0>(codepoint_span);
    308   const int codepoint_end = std::get<1>(codepoint_span);
    309 
    310   TokenIndex start_token = kInvalidIndex;
    311   TokenIndex end_token = kInvalidIndex;
    312   for (int i = 0; i < selectable_tokens.size(); ++i) {
    313     if (codepoint_start <= selectable_tokens[i].start &&
    314         codepoint_end >= selectable_tokens[i].end &&
    315         !selectable_tokens[i].is_padding) {
    316       if (start_token == kInvalidIndex) {
    317         start_token = i;
    318       }
    319       end_token = i + 1;
    320     }
    321   }
    322   return {start_token, end_token};
    323 }
    324 
    325 CodepointSpan TokenSpanToCodepointSpan(
    326     const std::vector<Token>& selectable_tokens, TokenSpan token_span) {
    327   return {selectable_tokens[token_span.first].start,
    328           selectable_tokens[token_span.second - 1].end};
    329 }
    330 
    331 namespace {
    332 
    333 // Finds a single token that completely contains the given span.
    334 int FindTokenThatContainsSpan(const std::vector<Token>& selectable_tokens,
    335                               CodepointSpan codepoint_span) {
    336   const int codepoint_start = std::get<0>(codepoint_span);
    337   const int codepoint_end = std::get<1>(codepoint_span);
    338 
    339   for (int i = 0; i < selectable_tokens.size(); ++i) {
    340     if (codepoint_start >= selectable_tokens[i].start &&
    341         codepoint_end <= selectable_tokens[i].end) {
    342       return i;
    343     }
    344   }
    345   return kInvalidIndex;
    346 }
    347 
    348 }  // namespace
    349 
    350 namespace internal {
    351 
    352 int CenterTokenFromClick(CodepointSpan span,
    353                          const std::vector<Token>& selectable_tokens) {
    354   int range_begin;
    355   int range_end;
    356   std::tie(range_begin, range_end) =
    357       CodepointSpanToTokenSpan(selectable_tokens, span);
    358 
    359   // If no exact match was found, try finding a token that completely contains
    360   // the click span. This is useful e.g. when Android builds the selection
    361   // using ICU tokenization, and ends up with only a portion of our space-
    362   // separated token. E.g. for "(857)" Android would select "857".
    363   if (range_begin == kInvalidIndex || range_end == kInvalidIndex) {
    364     int token_index = FindTokenThatContainsSpan(selectable_tokens, span);
    365     if (token_index != kInvalidIndex) {
    366       range_begin = token_index;
    367       range_end = token_index + 1;
    368     }
    369   }
    370 
    371   // We only allow clicks that are exactly 1 selectable token.
    372   if (range_end - range_begin == 1) {
    373     return range_begin;
    374   } else {
    375     return kInvalidIndex;
    376   }
    377 }
    378 
    379 int CenterTokenFromMiddleOfSelection(
    380     CodepointSpan span, const std::vector<Token>& selectable_tokens) {
    381   int range_begin;
    382   int range_end;
    383   std::tie(range_begin, range_end) =
    384       CodepointSpanToTokenSpan(selectable_tokens, span);
    385 
    386   // Center the clicked token in the selection range.
    387   if (range_begin != kInvalidIndex && range_end != kInvalidIndex) {
    388     return (range_begin + range_end - 1) / 2;
    389   } else {
    390     return kInvalidIndex;
    391   }
    392 }
    393 
    394 }  // namespace internal
    395 
    396 int FeatureProcessor::FindCenterToken(CodepointSpan span,
    397                                       const std::vector<Token>& tokens) const {
    398   if (options_.center_token_selection_method() ==
    399       FeatureProcessorOptions::CENTER_TOKEN_FROM_CLICK) {
    400     return internal::CenterTokenFromClick(span, tokens);
    401   } else if (options_.center_token_selection_method() ==
    402              FeatureProcessorOptions::CENTER_TOKEN_MIDDLE_OF_SELECTION) {
    403     return internal::CenterTokenFromMiddleOfSelection(span, tokens);
    404   } else if (options_.center_token_selection_method() ==
    405              FeatureProcessorOptions::DEFAULT_CENTER_TOKEN_METHOD) {
    406     // TODO(zilka): Remove once we have new models on the device.
    407     // It uses the fact that sharing model use
    408     // split_tokens_on_selection_boundaries and selection not. So depending on
    409     // this we select the right way of finding the click location.
    410     if (!options_.split_tokens_on_selection_boundaries()) {
    411       // SmartSelection model.
    412       return internal::CenterTokenFromClick(span, tokens);
    413     } else {
    414       // SmartSharing model.
    415       return internal::CenterTokenFromMiddleOfSelection(span, tokens);
    416     }
    417   } else {
    418     TC_LOG(ERROR) << "Invalid center token selection method.";
    419     return kInvalidIndex;
    420   }
    421 }
    422 
    423 bool FeatureProcessor::SelectionLabelSpans(
    424     const VectorSpan<Token> tokens,
    425     std::vector<CodepointSpan>* selection_label_spans) const {
    426   for (int i = 0; i < label_to_selection_.size(); ++i) {
    427     CodepointSpan span;
    428     if (!LabelToSpan(i, tokens, &span)) {
    429       TC_LOG(ERROR) << "Could not convert label to span: " << i;
    430       return false;
    431     }
    432     selection_label_spans->push_back(span);
    433   }
    434   return true;
    435 }
    436 
    437 void FeatureProcessor::PrepareCodepointRanges(
    438     const std::vector<FeatureProcessorOptions::CodepointRange>&
    439         codepoint_ranges,
    440     std::vector<CodepointRange>* prepared_codepoint_ranges) {
    441   prepared_codepoint_ranges->clear();
    442   prepared_codepoint_ranges->reserve(codepoint_ranges.size());
    443   for (const FeatureProcessorOptions::CodepointRange& range :
    444        codepoint_ranges) {
    445     prepared_codepoint_ranges->push_back(
    446         CodepointRange(range.start(), range.end()));
    447   }
    448 
    449   std::sort(prepared_codepoint_ranges->begin(),
    450             prepared_codepoint_ranges->end(),
    451             [](const CodepointRange& a, const CodepointRange& b) {
    452               return a.start < b.start;
    453             });
    454 }
    455 
    456 float FeatureProcessor::SupportedCodepointsRatio(
    457     int click_pos, const std::vector<Token>& tokens) const {
    458   int num_supported = 0;
    459   int num_total = 0;
    460   for (int i = click_pos - options_.context_size();
    461        i <= click_pos + options_.context_size(); ++i) {
    462     const bool is_valid_token = i >= 0 && i < tokens.size();
    463     if (is_valid_token) {
    464       const UnicodeText value =
    465           UTF8ToUnicodeText(tokens[i].value, /*do_copy=*/false);
    466       for (auto codepoint : value) {
    467         if (IsCodepointInRanges(codepoint, supported_codepoint_ranges_)) {
    468           ++num_supported;
    469         }
    470         ++num_total;
    471       }
    472     }
    473   }
    474   return static_cast<float>(num_supported) / static_cast<float>(num_total);
    475 }
    476 
    477 bool FeatureProcessor::IsCodepointInRanges(
    478     int codepoint, const std::vector<CodepointRange>& codepoint_ranges) const {
    479   auto it = std::lower_bound(codepoint_ranges.begin(), codepoint_ranges.end(),
    480                              codepoint,
    481                              [](const CodepointRange& range, int codepoint) {
    482                                // This function compares range with the
    483                                // codepoint for the purpose of finding the first
    484                                // greater or equal range. Because of the use of
    485                                // std::lower_bound it needs to return true when
    486                                // range < codepoint; the first time it will
    487                                // return false the lower bound is found and
    488                                // returned.
    489                                //
    490                                // It might seem weird that the condition is
    491                                // range.end <= codepoint here but when codepoint
    492                                // == range.end it means it's actually just
    493                                // outside of the range, thus the range is less
    494                                // than the codepoint.
    495                                return range.end <= codepoint;
    496                              });
    497   if (it != codepoint_ranges.end() && it->start <= codepoint &&
    498       it->end > codepoint) {
    499     return true;
    500   } else {
    501     return false;
    502   }
    503 }
    504 
    505 int FeatureProcessor::CollectionToLabel(const std::string& collection) const {
    506   const auto it = collection_to_label_.find(collection);
    507   if (it == collection_to_label_.end()) {
    508     return options_.default_collection();
    509   } else {
    510     return it->second;
    511   }
    512 }
    513 
    514 std::string FeatureProcessor::LabelToCollection(int label) const {
    515   if (label >= 0 && label < collection_to_label_.size()) {
    516     return options_.collections(label);
    517   } else {
    518     return GetDefaultCollection();
    519   }
    520 }
    521 
    522 void FeatureProcessor::MakeLabelMaps() {
    523   for (int i = 0; i < options_.collections().size(); ++i) {
    524     collection_to_label_[options_.collections(i)] = i;
    525   }
    526 
    527   int selection_label_id = 0;
    528   for (int l = 0; l < (options_.max_selection_span() + 1); ++l) {
    529     for (int r = 0; r < (options_.max_selection_span() + 1); ++r) {
    530       if (!options_.selection_reduced_output_space() ||
    531           r + l <= options_.max_selection_span()) {
    532         TokenSpan token_span{l, r};
    533         selection_to_label_[token_span] = selection_label_id;
    534         label_to_selection_.push_back(token_span);
    535         ++selection_label_id;
    536       }
    537     }
    538   }
    539 }
    540 
    541 void FeatureProcessor::TokenizeAndFindClick(const std::string& context,
    542                                             CodepointSpan input_span,
    543                                             std::vector<Token>* tokens,
    544                                             int* click_pos) const {
    545   TC_CHECK(tokens != nullptr);
    546   *tokens = Tokenize(context);
    547 
    548   if (options_.split_tokens_on_selection_boundaries()) {
    549     internal::SplitTokensOnSelectionBoundaries(input_span, tokens);
    550   }
    551 
    552   if (options_.only_use_line_with_click()) {
    553     internal::StripTokensFromOtherLines(context, input_span, tokens);
    554   }
    555 
    556   int local_click_pos;
    557   if (click_pos == nullptr) {
    558     click_pos = &local_click_pos;
    559   }
    560   *click_pos = FindCenterToken(input_span, *tokens);
    561 }
    562 
    563 namespace internal {
    564 
// Trims or pads `tokens` around index `*click_pos` so that it holds the
// context the caller needs. `*click_pos` is updated to keep pointing at the
// same click token.
//
// Context needed on each side of the click token, in tokens:
//   left:  relative_click_span.first  + context_size
//   right: relative_click_span.second + context_size
// Missing context is filled with default-constructed Tokens (presumably
// padding tokens; TODO confirm against Token's default constructor). At most
// `context_size` are added per side. Surplus tokens are erased.
//
// NOTE(review): on the right, the strip branch only runs when the needed end
// index is < tokens->size() - 1. When exactly one surplus token follows the
// needed range, that token is therefore kept. The left side has no such
// asymmetry. Confirm this is intended before relying on tokens->size().
void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
                      std::vector<Token>* tokens, int* click_pos) {
  // `*click_pos + right_context_needed + 1` below is the exclusive end index
  // of the needed range.
  int right_context_needed = relative_click_span.second + context_size;
  if (*click_pos + right_context_needed + 1 >= tokens->size()) {
    // Pad max the context size.
    const int num_pad_tokens = std::min(
        context_size, static_cast<int>(*click_pos + right_context_needed + 1 -
                                       tokens->size()));
    std::vector<Token> pad_tokens(num_pad_tokens);
    tokens->insert(tokens->end(), pad_tokens.begin(), pad_tokens.end());
  } else if (*click_pos + right_context_needed + 1 < tokens->size() - 1) {
    // Strip unused tokens.
    auto it = tokens->begin();
    std::advance(it, *click_pos + right_context_needed + 1);
    tokens->erase(it, tokens->end());
  }

  int left_context_needed = relative_click_span.first + context_size;
  if (*click_pos < left_context_needed) {
    // Pad max the context size.
    const int num_pad_tokens =
        std::min(context_size, left_context_needed - *click_pos);
    std::vector<Token> pad_tokens(num_pad_tokens);
    tokens->insert(tokens->begin(), pad_tokens.begin(), pad_tokens.end());
    // Prepended tokens shift the click token to the right.
    *click_pos += num_pad_tokens;
  } else if (*click_pos > left_context_needed) {
    // Strip unused tokens; the click token moves left by as many positions,
    // ending at index left_context_needed.
    auto it = tokens->begin();
    std::advance(it, *click_pos - left_context_needed);
    *click_pos -= it - tokens->begin();
    tokens->erase(tokens->begin(), it);
  }
}
    598 
    599 }  // namespace internal
    600 
    601 bool FeatureProcessor::ExtractFeatures(
    602     const std::string& context, CodepointSpan input_span,
    603     TokenSpan relative_click_span, const FeatureVectorFn& feature_vector_fn,
    604     int feature_vector_size, std::vector<Token>* tokens, int* click_pos,
    605     std::unique_ptr<CachedFeatures>* cached_features) const {
    606   TokenizeAndFindClick(context, input_span, tokens, click_pos);
    607 
    608   // If the default click method failed, let's try to do sub-token matching
    609   // before we fail.
    610   if (*click_pos == kInvalidIndex) {
    611     *click_pos = internal::CenterTokenFromClick(input_span, *tokens);
    612     if (*click_pos == kInvalidIndex) {
    613       return false;
    614     }
    615   }
    616 
    617   internal::StripOrPadTokens(relative_click_span, options_.context_size(),
    618                              tokens, click_pos);
    619 
    620   if (options_.min_supported_codepoint_ratio() > 0) {
    621     const float supported_codepoint_ratio =
    622         SupportedCodepointsRatio(*click_pos, *tokens);
    623     if (supported_codepoint_ratio < options_.min_supported_codepoint_ratio()) {
    624       TC_LOG(INFO) << "Not enough supported codepoints in the context: "
    625                    << supported_codepoint_ratio;
    626       return false;
    627     }
    628   }
    629 
    630   std::vector<std::vector<int>> sparse_features(tokens->size());
    631   std::vector<std::vector<float>> dense_features(tokens->size());
    632   for (int i = 0; i < tokens->size(); ++i) {
    633     const Token& token = (*tokens)[i];
    634     if (!feature_extractor_.Extract(token, token.IsContainedInSpan(input_span),
    635                                     &(sparse_features[i]),
    636                                     &(dense_features[i]))) {
    637       TC_LOG(ERROR) << "Could not extract token's features: " << token;
    638       return false;
    639     }
    640   }
    641 
    642   cached_features->reset(new CachedFeatures(
    643       *tokens, options_.context_size(), sparse_features, dense_features,
    644       feature_vector_fn, feature_vector_size));
    645 
    646   if (*cached_features == nullptr) {
    647     return false;
    648   }
    649 
    650   if (options_.feature_version() == 0) {
    651     (*cached_features)
    652         ->SetV0FeatureMode(feature_vector_size -
    653                            feature_extractor_.DenseFeaturesCount());
    654   }
    655 
    656   return true;
    657 }
    658 
    659 bool FeatureProcessor::ICUTokenize(const std::string& context,
    660                                    std::vector<Token>* result) const {
    661   icu::ErrorCode status;
    662   icu::UnicodeString unicode_text = icu::UnicodeString::fromUTF8(context);
    663   std::unique_ptr<icu::BreakIterator> break_iterator(
    664       icu::BreakIterator::createWordInstance(icu::Locale("en"), status));
    665   if (!status.isSuccess()) {
    666     TC_LOG(ERROR) << "Break iterator did not initialize properly: "
    667                   << status.errorName();
    668     return false;
    669   }
    670 
    671   break_iterator->setText(unicode_text);
    672 
    673   size_t last_break_index = 0;
    674   size_t break_index = 0;
    675   size_t last_unicode_index = 0;
    676   size_t unicode_index = 0;
    677   while ((break_index = break_iterator->next()) != icu::BreakIterator::DONE) {
    678     icu::UnicodeString token(unicode_text, last_break_index,
    679                              break_index - last_break_index);
    680     int token_length = token.countChar32();
    681     unicode_index = last_unicode_index + token_length;
    682 
    683     std::string token_utf8;
    684     token.toUTF8String(token_utf8);
    685 
    686     bool is_whitespace = true;
    687     for (int i = 0; i < token.length(); i++) {
    688       if (!u_isWhitespace(token.char32At(i))) {
    689         is_whitespace = false;
    690       }
    691     }
    692 
    693     if (!is_whitespace || options_.icu_preserve_whitespace_tokens()) {
    694       result->push_back(Token(token_utf8, last_unicode_index, unicode_index));
    695     }
    696 
    697     last_break_index = break_index;
    698     last_unicode_index = unicode_index;
    699   }
    700 
    701   return true;
    702 }
    703 
    704 void FeatureProcessor::InternalRetokenize(const std::string& context,
    705                                           std::vector<Token>* tokens) const {
    706   const UnicodeText unicode_text =
    707       UTF8ToUnicodeText(context, /*do_copy=*/false);
    708 
    709   std::vector<Token> result;
    710   CodepointSpan span(-1, -1);
    711   for (Token& token : *tokens) {
    712     const UnicodeText unicode_token_value =
    713         UTF8ToUnicodeText(token.value, /*do_copy=*/false);
    714     bool should_retokenize = true;
    715     for (const int codepoint : unicode_token_value) {
    716       if (!IsCodepointInRanges(codepoint,
    717                                internal_tokenizer_codepoint_ranges_)) {
    718         should_retokenize = false;
    719         break;
    720       }
    721     }
    722 
    723     if (should_retokenize) {
    724       if (span.first < 0) {
    725         span.first = token.start;
    726       }
    727       span.second = token.end;
    728     } else {
    729       TokenizeSubstring(unicode_text, span, &result);
    730       span.first = -1;
    731       result.emplace_back(std::move(token));
    732     }
    733   }
    734   TokenizeSubstring(unicode_text, span, &result);
    735 
    736   *tokens = std::move(result);
    737 }
    738 
    739 void FeatureProcessor::TokenizeSubstring(const UnicodeText& unicode_text,
    740                                          CodepointSpan span,
    741                                          std::vector<Token>* result) const {
    742   if (span.first < 0) {
    743     // There is no span to tokenize.
    744     return;
    745   }
    746 
    747   // Extract the substring.
    748   UnicodeText::const_iterator it_begin = unicode_text.begin();
    749   for (int i = 0; i < span.first; ++i) {
    750     ++it_begin;
    751   }
    752   UnicodeText::const_iterator it_end = unicode_text.begin();
    753   for (int i = 0; i < span.second; ++i) {
    754     ++it_end;
    755   }
    756   const std::string text = unicode_text.UTF8Substring(it_begin, it_end);
    757 
    758   // Run the tokenizer and update the token bounds to reflect the offset of the
    759   // substring.
    760   std::vector<Token> tokens = tokenizer_.Tokenize(text);
    761   for (Token& token : tokens) {
    762     token.start += span.first;
    763     token.end += span.first;
    764     result->emplace_back(std::move(token));
    765   }
    766 }
    767 
    768 }  // namespace libtextclassifier
    769