Home | History | Annotate | Download | only in smartselect
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "smartselect/token-feature-extractor.h"
     18 
     19 #include <string>
     20 
     21 #include "util/base/logging.h"
     22 #include "util/hash/farmhash.h"
     23 #include "util/strings/stringpiece.h"
     24 #include "util/utf8/unicodetext.h"
     25 #include "unicode/regex.h"
     26 #include "unicode/uchar.h"
     27 
     28 namespace libtextclassifier {
     29 
     30 namespace {
     31 
     32 std::string RemapTokenAscii(const std::string& token,
     33                             const TokenFeatureExtractorOptions& options) {
     34   if (!options.remap_digits && !options.lowercase_tokens) {
     35     return token;
     36   }
     37 
     38   std::string copy = token;
     39   for (int i = 0; i < token.size(); ++i) {
     40     if (options.remap_digits && isdigit(copy[i])) {
     41       copy[i] = '0';
     42     }
     43     if (options.lowercase_tokens) {
     44       copy[i] = tolower(copy[i]);
     45     }
     46   }
     47   return copy;
     48 }
     49 
     50 void RemapTokenUnicode(const std::string& token,
     51                        const TokenFeatureExtractorOptions& options,
     52                        UnicodeText* remapped) {
     53   if (!options.remap_digits && !options.lowercase_tokens) {
     54     // Leave remapped untouched.
     55     return;
     56   }
     57 
     58   UnicodeText word = UTF8ToUnicodeText(token, /*do_copy=*/false);
     59   icu::UnicodeString icu_string;
     60   for (auto it = word.begin(); it != word.end(); ++it) {
     61     if (options.remap_digits && u_isdigit(*it)) {
     62       icu_string.append('0');
     63     } else if (options.lowercase_tokens) {
     64       icu_string.append(u_tolower(*it));
     65     } else {
     66       icu_string.append(*it);
     67     }
     68   }
     69   std::string utf8_str;
     70   icu_string.toUTF8String(utf8_str);
     71   remapped->CopyUTF8(utf8_str.data(), utf8_str.length());
     72 }
     73 
     74 }  // namespace
     75 
     76 TokenFeatureExtractor::TokenFeatureExtractor(
     77     const TokenFeatureExtractorOptions& options)
     78     : options_(options) {
     79   UErrorCode status;
     80   for (const std::string& pattern : options.regexp_features) {
     81     status = U_ZERO_ERROR;
     82     regex_patterns_.push_back(
     83         std::unique_ptr<icu::RegexPattern>(icu::RegexPattern::compile(
     84             icu::UnicodeString(pattern.c_str(), pattern.size(), "utf-8"), 0,
     85             status)));
     86     if (U_FAILURE(status)) {
     87       TC_LOG(WARNING) << "Failed to load pattern" << pattern;
     88     }
     89   }
     90 }
     91 
     92 int TokenFeatureExtractor::HashToken(StringPiece token) const {
     93   return tcfarmhash::Fingerprint64(token) % options_.num_buckets;
     94 }
     95 
     96 std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeatures(
     97     const Token& token) const {
     98   if (options_.unicode_aware_features) {
     99     return ExtractCharactergramFeaturesUnicode(token);
    100   } else {
    101     return ExtractCharactergramFeaturesAscii(token);
    102   }
    103 }
    104 
    105 std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesAscii(
    106     const Token& token) const {
    107   std::vector<int> result;
    108   if (token.is_padding || token.value.empty()) {
    109     result.push_back(HashToken("<PAD>"));
    110   } else {
    111     const std::string word = RemapTokenAscii(token.value, options_);
    112 
    113     // Trim words that are over max_word_length characters.
    114     const int max_word_length = options_.max_word_length;
    115     std::string feature_word;
    116     if (word.size() > max_word_length) {
    117       feature_word =
    118           "^" + word.substr(0, max_word_length / 2) + "\1" +
    119           word.substr(word.size() - max_word_length / 2, max_word_length / 2) +
    120           "$";
    121     } else {
    122       // Add a prefix and suffix to the word.
    123       feature_word = "^" + word + "$";
    124     }
    125 
    126     // Upper-bound the number of charactergram extracted to avoid resizing.
    127     result.reserve(options_.chargram_orders.size() * feature_word.size());
    128 
    129     // Generate the character-grams.
    130     for (int chargram_order : options_.chargram_orders) {
    131       if (chargram_order == 1) {
    132         for (int i = 1; i < feature_word.size() - 1; ++i) {
    133           result.push_back(
    134               HashToken(StringPiece(feature_word, /*offset=*/i, /*len=*/1)));
    135         }
    136       } else {
    137         for (int i = 0;
    138              i < static_cast<int>(feature_word.size()) - chargram_order + 1;
    139              ++i) {
    140           result.push_back(HashToken(
    141               StringPiece(feature_word, /*offset=*/i, /*len=*/chargram_order)));
    142         }
    143       }
    144     }
    145   }
    146   return result;
    147 }
    148 
    149 std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesUnicode(
    150     const Token& token) const {
    151   std::vector<int> result;
    152   if (token.is_padding || token.value.empty()) {
    153     result.push_back(HashToken("<PAD>"));
    154   } else {
    155     UnicodeText word = UTF8ToUnicodeText(token.value, /*do_copy=*/false);
    156     RemapTokenUnicode(token.value, options_, &word);
    157 
    158     // Trim the word if needed by finding a left-cut point and right-cut point.
    159     auto left_cut = word.begin();
    160     auto right_cut = word.end();
    161     for (int i = 0; i < options_.max_word_length / 2; i++) {
    162       if (left_cut < right_cut) {
    163         ++left_cut;
    164       }
    165       if (left_cut < right_cut) {
    166         --right_cut;
    167       }
    168     }
    169 
    170     std::string feature_word;
    171     if (left_cut == right_cut) {
    172       feature_word = "^" + word.UTF8Substring(word.begin(), word.end()) + "$";
    173     } else {
    174       // clang-format off
    175       feature_word = "^" +
    176                      word.UTF8Substring(word.begin(), left_cut) +
    177                      "\1" +
    178                      word.UTF8Substring(right_cut, word.end()) +
    179                      "$";
    180       // clang-format on
    181     }
    182 
    183     const UnicodeText feature_word_unicode =
    184         UTF8ToUnicodeText(feature_word, /*do_copy=*/false);
    185 
    186     // Upper-bound the number of charactergram extracted to avoid resizing.
    187     result.reserve(options_.chargram_orders.size() * feature_word.size());
    188 
    189     // Generate the character-grams.
    190     for (int chargram_order : options_.chargram_orders) {
    191       UnicodeText::const_iterator it_start = feature_word_unicode.begin();
    192       UnicodeText::const_iterator it_end = feature_word_unicode.end();
    193       if (chargram_order == 1) {
    194         ++it_start;
    195         --it_end;
    196       }
    197 
    198       UnicodeText::const_iterator it_chargram_start = it_start;
    199       UnicodeText::const_iterator it_chargram_end = it_start;
    200       bool chargram_is_complete = true;
    201       for (int i = 0; i < chargram_order; ++i) {
    202         if (it_chargram_end == it_end) {
    203           chargram_is_complete = false;
    204           break;
    205         }
    206         ++it_chargram_end;
    207       }
    208       if (!chargram_is_complete) {
    209         continue;
    210       }
    211 
    212       for (; it_chargram_end <= it_end;
    213            ++it_chargram_start, ++it_chargram_end) {
    214         const int length_bytes =
    215             it_chargram_end.utf8_data() - it_chargram_start.utf8_data();
    216         result.push_back(HashToken(
    217             StringPiece(it_chargram_start.utf8_data(), length_bytes)));
    218       }
    219     }
    220   }
    221   return result;
    222 }
    223 
    224 bool TokenFeatureExtractor::Extract(const Token& token, bool is_in_span,
    225                                     std::vector<int>* sparse_features,
    226                                     std::vector<float>* dense_features) const {
    227   if (sparse_features == nullptr || dense_features == nullptr) {
    228     return false;
    229   }
    230 
    231   *sparse_features = ExtractCharactergramFeatures(token);
    232 
    233   if (options_.extract_case_feature) {
    234     if (options_.unicode_aware_features) {
    235       UnicodeText token_unicode =
    236           UTF8ToUnicodeText(token.value, /*do_copy=*/false);
    237       if (!token.value.empty() && u_isupper(*token_unicode.begin())) {
    238         dense_features->push_back(1.0);
    239       } else {
    240         dense_features->push_back(-1.0);
    241       }
    242     } else {
    243       if (!token.value.empty() && isupper(*token.value.begin())) {
    244         dense_features->push_back(1.0);
    245       } else {
    246         dense_features->push_back(-1.0);
    247       }
    248     }
    249   }
    250 
    251   if (options_.extract_selection_mask_feature) {
    252     if (is_in_span) {
    253       dense_features->push_back(1.0);
    254     } else {
    255       if (options_.unicode_aware_features) {
    256         dense_features->push_back(-1.0);
    257       } else {
    258         dense_features->push_back(0.0);
    259       }
    260     }
    261   }
    262 
    263   // Add regexp features.
    264   if (!regex_patterns_.empty()) {
    265     icu::UnicodeString unicode_str(token.value.c_str(), token.value.size(),
    266                                    "utf-8");
    267     for (int i = 0; i < regex_patterns_.size(); ++i) {
    268       if (!regex_patterns_[i].get()) {
    269         dense_features->push_back(-1.0);
    270         continue;
    271       }
    272 
    273       // Check for match.
    274       UErrorCode status = U_ZERO_ERROR;
    275       std::unique_ptr<icu::RegexMatcher> matcher(
    276           regex_patterns_[i]->matcher(unicode_str, status));
    277       if (matcher->find()) {
    278         dense_features->push_back(1.0);
    279       } else {
    280         dense_features->push_back(-1.0);
    281       }
    282     }
    283   }
    284   return true;
    285 }
    286 
    287 }  // namespace libtextclassifier
    288