1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "token-feature-extractor.h" 18 19 #include <cctype> 20 #include <string> 21 22 #include "util/base/logging.h" 23 #include "util/hash/farmhash.h" 24 #include "util/strings/stringpiece.h" 25 #include "util/utf8/unicodetext.h" 26 27 namespace libtextclassifier2 { 28 29 namespace { 30 31 std::string RemapTokenAscii(const std::string& token, 32 const TokenFeatureExtractorOptions& options) { 33 if (!options.remap_digits && !options.lowercase_tokens) { 34 return token; 35 } 36 37 std::string copy = token; 38 for (int i = 0; i < token.size(); ++i) { 39 if (options.remap_digits && isdigit(copy[i])) { 40 copy[i] = '0'; 41 } 42 if (options.lowercase_tokens) { 43 copy[i] = tolower(copy[i]); 44 } 45 } 46 return copy; 47 } 48 49 void RemapTokenUnicode(const std::string& token, 50 const TokenFeatureExtractorOptions& options, 51 const UniLib& unilib, UnicodeText* remapped) { 52 if (!options.remap_digits && !options.lowercase_tokens) { 53 // Leave remapped untouched. 54 return; 55 } 56 57 UnicodeText word = UTF8ToUnicodeText(token, /*do_copy=*/false); 58 remapped->clear(); 59 for (auto it = word.begin(); it != word.end(); ++it) { 60 if (options.remap_digits && unilib.IsDigit(*it)) { 61 remapped->AppendCodepoint('0'); 62 } else if (options.lowercase_tokens) { 63 remapped->AppendCodepoint(unilib.ToLower(*it)); 64 } else { 65 remapped->AppendCodepoint(*it); 66 } 67 } 68 } 69 70 } // namespace 71 72 TokenFeatureExtractor::TokenFeatureExtractor( 73 const TokenFeatureExtractorOptions& options, const UniLib& unilib) 74 : options_(options), unilib_(unilib) { 75 for (const std::string& pattern : options.regexp_features) { 76 regex_patterns_.push_back(std::unique_ptr<UniLib::RegexPattern>( 77 unilib_.CreateRegexPattern(UTF8ToUnicodeText( 78 pattern.c_str(), pattern.size(), /*do_copy=*/false)))); 79 } 80 } 81 82 bool TokenFeatureExtractor::Extract(const Token& token, bool is_in_span, 83 std::vector<int>* sparse_features, 84 std::vector<float>* dense_features) const { 85 if (!dense_features) { 86 return false; 87 } 88 if (sparse_features) { 89 *sparse_features = ExtractCharactergramFeatures(token); 90 } 91 *dense_features = ExtractDenseFeatures(token, is_in_span); 92 return true; 93 } 94 95 std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeatures( 96 const Token& token) const { 97 if (options_.unicode_aware_features) { 98 return ExtractCharactergramFeaturesUnicode(token); 99 } else { 100 return ExtractCharactergramFeaturesAscii(token); 101 } 102 } 103 104 std::vector<float> TokenFeatureExtractor::ExtractDenseFeatures( 105 const Token& token, bool is_in_span) const { 106 std::vector<float> dense_features; 107 108 if (options_.extract_case_feature) { 109 if (options_.unicode_aware_features) { 110 UnicodeText token_unicode = 111 UTF8ToUnicodeText(token.value, /*do_copy=*/false); 112 const bool is_upper = unilib_.IsUpper(*token_unicode.begin()); 113 if (!token.value.empty() && is_upper) { 114 dense_features.push_back(1.0); 115 } else { 116 dense_features.push_back(-1.0); 117 } 118 } else { 119 if (!token.value.empty() && isupper(*token.value.begin())) { 120 dense_features.push_back(1.0); 121 } else { 122 dense_features.push_back(-1.0); 123 } 124 } 125 } 126 127 if (options_.extract_selection_mask_feature) { 128 if (is_in_span) { 129 dense_features.push_back(1.0); 130 } else { 131 if (options_.unicode_aware_features) { 132 dense_features.push_back(-1.0); 133 } else { 134 dense_features.push_back(0.0); 135 } 136 } 137 } 138 139 // Add regexp features. 140 if (!regex_patterns_.empty()) { 141 UnicodeText token_unicode = 142 UTF8ToUnicodeText(token.value, /*do_copy=*/false); 143 for (int i = 0; i < regex_patterns_.size(); ++i) { 144 if (!regex_patterns_[i].get()) { 145 dense_features.push_back(-1.0); 146 continue; 147 } 148 auto matcher = regex_patterns_[i]->Matcher(token_unicode); 149 int status; 150 if (matcher->Matches(&status)) { 151 dense_features.push_back(1.0); 152 } else { 153 dense_features.push_back(-1.0); 154 } 155 } 156 } 157 158 return dense_features; 159 } 160 161 int TokenFeatureExtractor::HashToken(StringPiece token) const { 162 if (options_.allowed_chargrams.empty()) { 163 return tc2farmhash::Fingerprint64(token) % options_.num_buckets; 164 } else { 165 // Padding and out-of-vocabulary tokens have extra buckets reserved because 166 // they are special and important tokens, and we don't want them to share 167 // embedding with other charactergrams. 168 // TODO(zilka): Experimentally verify. 169 const int kNumExtraBuckets = 2; 170 const std::string token_string = token.ToString(); 171 if (token_string == "<PAD>") { 172 return 1; 173 } else if (options_.allowed_chargrams.find(token_string) == 174 options_.allowed_chargrams.end()) { 175 return 0; // Out-of-vocabulary. 176 } else { 177 return (tc2farmhash::Fingerprint64(token) % 178 (options_.num_buckets - kNumExtraBuckets)) + 179 kNumExtraBuckets; 180 } 181 } 182 } 183 184 std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesAscii( 185 const Token& token) const { 186 std::vector<int> result; 187 if (token.is_padding || token.value.empty()) { 188 result.push_back(HashToken("<PAD>")); 189 } else { 190 const std::string word = RemapTokenAscii(token.value, options_); 191 192 // Trim words that are over max_word_length characters. 193 const int max_word_length = options_.max_word_length; 194 std::string feature_word; 195 if (word.size() > max_word_length) { 196 feature_word = 197 "^" + word.substr(0, max_word_length / 2) + "\1" + 198 word.substr(word.size() - max_word_length / 2, max_word_length / 2) + 199 "$"; 200 } else { 201 // Add a prefix and suffix to the word. 202 feature_word = "^" + word + "$"; 203 } 204 205 // Upper-bound the number of charactergram extracted to avoid resizing. 206 result.reserve(options_.chargram_orders.size() * feature_word.size()); 207 208 if (options_.chargram_orders.empty()) { 209 result.push_back(HashToken(feature_word)); 210 } else { 211 // Generate the character-grams. 212 for (int chargram_order : options_.chargram_orders) { 213 if (chargram_order == 1) { 214 for (int i = 1; i < feature_word.size() - 1; ++i) { 215 result.push_back( 216 HashToken(StringPiece(feature_word, /*offset=*/i, /*len=*/1))); 217 } 218 } else { 219 for (int i = 0; 220 i < static_cast<int>(feature_word.size()) - chargram_order + 1; 221 ++i) { 222 result.push_back(HashToken(StringPiece(feature_word, /*offset=*/i, 223 /*len=*/chargram_order))); 224 } 225 } 226 } 227 } 228 } 229 return result; 230 } 231 232 std::vector<int> TokenFeatureExtractor::ExtractCharactergramFeaturesUnicode( 233 const Token& token) const { 234 std::vector<int> result; 235 if (token.is_padding || token.value.empty()) { 236 result.push_back(HashToken("<PAD>")); 237 } else { 238 UnicodeText word = UTF8ToUnicodeText(token.value, /*do_copy=*/false); 239 RemapTokenUnicode(token.value, options_, unilib_, &word); 240 241 // Trim the word if needed by finding a left-cut point and right-cut point. 242 auto left_cut = word.begin(); 243 auto right_cut = word.end(); 244 for (int i = 0; i < options_.max_word_length / 2; i++) { 245 if (left_cut < right_cut) { 246 ++left_cut; 247 } 248 if (left_cut < right_cut) { 249 --right_cut; 250 } 251 } 252 253 std::string feature_word; 254 if (left_cut == right_cut) { 255 feature_word = "^" + word.UTF8Substring(word.begin(), word.end()) + "$"; 256 } else { 257 // clang-format off 258 feature_word = "^" + 259 word.UTF8Substring(word.begin(), left_cut) + 260 "\1" + 261 word.UTF8Substring(right_cut, word.end()) + 262 "$"; 263 // clang-format on 264 } 265 266 const UnicodeText feature_word_unicode = 267 UTF8ToUnicodeText(feature_word, /*do_copy=*/false); 268 269 // Upper-bound the number of charactergram extracted to avoid resizing. 270 result.reserve(options_.chargram_orders.size() * feature_word.size()); 271 272 if (options_.chargram_orders.empty()) { 273 result.push_back(HashToken(feature_word)); 274 } else { 275 // Generate the character-grams. 276 for (int chargram_order : options_.chargram_orders) { 277 UnicodeText::const_iterator it_start = feature_word_unicode.begin(); 278 UnicodeText::const_iterator it_end = feature_word_unicode.end(); 279 if (chargram_order == 1) { 280 ++it_start; 281 --it_end; 282 } 283 284 UnicodeText::const_iterator it_chargram_start = it_start; 285 UnicodeText::const_iterator it_chargram_end = it_start; 286 bool chargram_is_complete = true; 287 for (int i = 0; i < chargram_order; ++i) { 288 if (it_chargram_end == it_end) { 289 chargram_is_complete = false; 290 break; 291 } 292 ++it_chargram_end; 293 } 294 if (!chargram_is_complete) { 295 continue; 296 } 297 298 for (; it_chargram_end <= it_end; 299 ++it_chargram_start, ++it_chargram_end) { 300 const int length_bytes = 301 it_chargram_end.utf8_data() - it_chargram_start.utf8_data(); 302 result.push_back(HashToken( 303 StringPiece(it_chargram_start.utf8_data(), length_bytes))); 304 } 305 } 306 } 307 } 308 return result; 309 } 310 311 } // namespace libtextclassifier2 312