Home | History | Annotate | Download | only in libtextclassifier
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "token-feature-extractor.h"
     18 
     19 #include <cctype>
     20 #include <string>
     21 
     22 #include "util/base/logging.h"
     23 #include "util/hash/farmhash.h"
     24 #include "util/strings/stringpiece.h"
     25 #include "util/utf8/unicodetext.h"
     26 
     27 namespace libtextclassifier2 {
     28 
     29 namespace {
     30 
     31 std::string RemapTokenAscii(const std::string& token,
     32                             const TokenFeatureExtractorOptions& options) {
     33   if (!options.remap_digits && !options.lowercase_tokens) {
     34     return token;
     35   }
     36 
     37   std::string copy = token;
     38   for (int i = 0; i < token.size(); ++i) {
     39     if (options.remap_digits && isdigit(copy[i])) {
     40       copy[i] = '0';
     41     }
     42     if (options.lowercase_tokens) {
     43       copy[i] = tolower(copy[i]);
     44     }
     45   }
     46   return copy;
     47 }
     48 
     49 void RemapTokenUnicode(const std::string& token,
     50                        const TokenFeatureExtractorOptions& options,
     51                        const UniLib& unilib, UnicodeText* remapped) {
     52   if (!options.remap_digits && !options.lowercase_tokens) {
     53     // Leave remapped untouched.
     54     return;
     55   }
     56 
     57   UnicodeText word = UTF8ToUnicodeText(token, /*do_copy=*/false);
     58   remapped->clear();
     59   for (auto it = word.begin(); it != word.end(); ++it) {
     60     if (options.remap_digits && unilib.IsDigit(*it)) {
     61       remapped->AppendCodepoint('0');
     62     } else if (options.lowercase_tokens) {
     63       remapped->AppendCodepoint(unilib.ToLower(*it));
     64     } else {
     65       remapped->AppendCodepoint(*it);
     66     }
     67   }
     68 }
     69 
     70 }  // namespace
     71 
     72 TokenFeatureExtractor::TokenFeatureExtractor(
     73     const TokenFeatureExtractorOptions& options, const UniLib& unilib)
     74     : options_(options), unilib_(unilib) {
     75   for (const std::string& pattern : options.regexp_features) {
     76     regex_patterns_.push_back(std::unique_ptr<UniLib::RegexPattern>(
     77         unilib_.CreateRegexPattern(UTF8ToUnicodeText(
     78             pattern.c_str(), pattern.size(), /*do_copy=*/false))));
     79   }
     80 }
     81 
     82 bool TokenFeatureExtractor::Extract(const Token& token, bool is_in_span,
     83                                     std::vector<int>* sparse_features,
     84                                     std::vector<float>* dense_features) const {
     85   if (!dense_features) {
     86     return false;
     87   }
     88   if (sparse_features) {
     89     *sparse_features = ExtractCharactergramFeatures(token);
     90   }
     91   *dense_features = ExtractDenseFeatures(token, is_in_span);
     92   return true;
     93 }
     94 
     95 std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeatures(
     96     const Token& token) const {
     97   if (options_.unicode_aware_features) {
     98     return ExtractCharactergramFeaturesUnicode(token);
     99   } else {
    100     return ExtractCharactergramFeaturesAscii(token);
    101   }
    102 }
    103 
    104 std::vector<float> TokenFeatureExtractor::ExtractDenseFeatures(
    105     const Token& token, bool is_in_span) const {
    106   std::vector<float> dense_features;
    107 
    108   if (options_.extract_case_feature) {
    109     if (options_.unicode_aware_features) {
    110       UnicodeText token_unicode =
    111           UTF8ToUnicodeText(token.value, /*do_copy=*/false);
    112       const bool is_upper = unilib_.IsUpper(*token_unicode.begin());
    113       if (!token.value.empty() && is_upper) {
    114         dense_features.push_back(1.0);
    115       } else {
    116         dense_features.push_back(-1.0);
    117       }
    118     } else {
    119       if (!token.value.empty() && isupper(*token.value.begin())) {
    120         dense_features.push_back(1.0);
    121       } else {
    122         dense_features.push_back(-1.0);
    123       }
    124     }
    125   }
    126 
    127   if (options_.extract_selection_mask_feature) {
    128     if (is_in_span) {
    129       dense_features.push_back(1.0);
    130     } else {
    131       if (options_.unicode_aware_features) {
    132         dense_features.push_back(-1.0);
    133       } else {
    134         dense_features.push_back(0.0);
    135       }
    136     }
    137   }
    138 
    139   // Add regexp features.
    140   if (!regex_patterns_.empty()) {
    141     UnicodeText token_unicode =
    142         UTF8ToUnicodeText(token.value, /*do_copy=*/false);
    143     for (int i = 0; i < regex_patterns_.size(); ++i) {
    144       if (!regex_patterns_[i].get()) {
    145         dense_features.push_back(-1.0);
    146         continue;
    147       }
    148       auto matcher = regex_patterns_[i]->Matcher(token_unicode);
    149       int status;
    150       if (matcher->Matches(&status)) {
    151         dense_features.push_back(1.0);
    152       } else {
    153         dense_features.push_back(-1.0);
    154       }
    155     }
    156   }
    157 
    158   return dense_features;
    159 }
    160 
    161 int TokenFeatureExtractor::HashToken(StringPiece token) const {
    162   if (options_.allowed_chargrams.empty()) {
    163     return tc2farmhash::Fingerprint64(token) % options_.num_buckets;
    164   } else {
    165     // Padding and out-of-vocabulary tokens have extra buckets reserved because
    166     // they are special and important tokens, and we don't want them to share
    167     // embedding with other charactergrams.
    168     // TODO(zilka): Experimentally verify.
    169     const int kNumExtraBuckets = 2;
    170     const std::string token_string = token.ToString();
    171     if (token_string == "<PAD>") {
    172       return 1;
    173     } else if (options_.allowed_chargrams.find(token_string) ==
    174                options_.allowed_chargrams.end()) {
    175       return 0;  // Out-of-vocabulary.
    176     } else {
    177       return (tc2farmhash::Fingerprint64(token) %
    178               (options_.num_buckets - kNumExtraBuckets)) +
    179              kNumExtraBuckets;
    180     }
    181   }
    182 }
    183 
    184 std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesAscii(
    185     const Token& token) const {
    186   std::vector<int> result;
    187   if (token.is_padding || token.value.empty()) {
    188     result.push_back(HashToken("<PAD>"));
    189   } else {
    190     const std::string word = RemapTokenAscii(token.value, options_);
    191 
    192     // Trim words that are over max_word_length characters.
    193     const int max_word_length = options_.max_word_length;
    194     std::string feature_word;
    195     if (word.size() > max_word_length) {
    196       feature_word =
    197           "^" + word.substr(0, max_word_length / 2) + "\1" +
    198           word.substr(word.size() - max_word_length / 2, max_word_length / 2) +
    199           "$";
    200     } else {
    201       // Add a prefix and suffix to the word.
    202       feature_word = "^" + word + "$";
    203     }
    204 
    205     // Upper-bound the number of charactergram extracted to avoid resizing.
    206     result.reserve(options_.chargram_orders.size() * feature_word.size());
    207 
    208     if (options_.chargram_orders.empty()) {
    209       result.push_back(HashToken(feature_word));
    210     } else {
    211       // Generate the character-grams.
    212       for (int chargram_order : options_.chargram_orders) {
    213         if (chargram_order == 1) {
    214           for (int i = 1; i < feature_word.size() - 1; ++i) {
    215             result.push_back(
    216                 HashToken(StringPiece(feature_word, /*offset=*/i, /*len=*/1)));
    217           }
    218         } else {
    219           for (int i = 0;
    220                i < static_cast<int>(feature_word.size()) - chargram_order + 1;
    221                ++i) {
    222             result.push_back(HashToken(StringPiece(feature_word, /*offset=*/i,
    223                                                    /*len=*/chargram_order)));
    224           }
    225         }
    226       }
    227     }
    228   }
    229   return result;
    230 }
    231 
// Unicode-aware variant of ExtractCharactergramFeaturesAscii: operates on
// codepoints instead of bytes. Padding/empty tokens yield a single "<PAD>"
// bucket; otherwise the token is remapped, trimmed to ~max_word_length
// codepoints with a "\1" elision marker, wrapped in "^...$" anchors, and
// sliced into codepoint n-grams of each configured order.
std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesUnicode(
    const Token& token) const {
  std::vector<int> result;
  if (token.is_padding || token.value.empty()) {
    result.push_back(HashToken("<PAD>"));
  } else {
    // `word` starts as a non-owning view of the token bytes; when any
    // remapping option is enabled, RemapTokenUnicode overwrites it with the
    // remapped codepoints (otherwise it is left untouched).
    UnicodeText word = UTF8ToUnicodeText(token.value, /*do_copy=*/false);
    RemapTokenUnicode(token.value, options_, unilib_, &word);

    // Trim the word if needed by finding a left-cut point and right-cut point.
    // Advance from both ends up to max_word_length/2 codepoints each; the
    // cuts meeting means the word fits and nothing is elided.
    auto left_cut = word.begin();
    auto right_cut = word.end();
    for (int i = 0; i < options_.max_word_length / 2; i++) {
      if (left_cut < right_cut) {
        ++left_cut;
      }
      if (left_cut < right_cut) {
        --right_cut;
      }
    }

    std::string feature_word;
    if (left_cut == right_cut) {
      // The whole word fits within max_word_length: just add anchors.
      feature_word = "^" + word.UTF8Substring(word.begin(), word.end()) + "$";
    } else {
      // Overlong word: keep the prefix before left_cut and the suffix after
      // right_cut, with "\1" marking the removed middle.
      // clang-format off
      feature_word = "^" +
                     word.UTF8Substring(word.begin(), left_cut) +
                     "\1" +
                     word.UTF8Substring(right_cut, word.end()) +
                     "$";
      // clang-format on
    }

    const UnicodeText feature_word_unicode =
        UTF8ToUnicodeText(feature_word, /*do_copy=*/false);

    // Upper-bound the number of charactergram extracted to avoid resizing.
    // (feature_word.size() is a byte count, which is >= the codepoint count.)
    result.reserve(options_.chargram_orders.size() * feature_word.size());

    if (options_.chargram_orders.empty()) {
      result.push_back(HashToken(feature_word));
    } else {
      // Generate the character-grams.
      for (int chargram_order : options_.chargram_orders) {
        UnicodeText::const_iterator it_start = feature_word_unicode.begin();
        UnicodeText::const_iterator it_end = feature_word_unicode.end();
        if (chargram_order == 1) {
          // Unigrams skip the "^" and "$" anchor codepoints.
          ++it_start;
          --it_end;
        }

        // Position the window end chargram_order codepoints past the start;
        // if the range is shorter than one window, emit nothing for this
        // order.
        UnicodeText::const_iterator it_chargram_start = it_start;
        UnicodeText::const_iterator it_chargram_end = it_start;
        bool chargram_is_complete = true;
        for (int i = 0; i < chargram_order; ++i) {
          if (it_chargram_end == it_end) {
            chargram_is_complete = false;
            break;
          }
          ++it_chargram_end;
        }
        if (!chargram_is_complete) {
          continue;
        }

        // Slide the [start, end) window one codepoint at a time, hashing the
        // underlying UTF-8 bytes of each window directly (no copy).
        for (; it_chargram_end <= it_end;
             ++it_chargram_start, ++it_chargram_end) {
          const int length_bytes =
              it_chargram_end.utf8_data() - it_chargram_start.utf8_data();
          result.push_back(HashToken(
              StringPiece(it_chargram_start.utf8_data(), length_bytes)));
        }
      }
    }
  }
  return result;
}
    310 
    311 }  // namespace libtextclassifier2
    312