1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "smartselect/text-classification-model.h" 18 19 #include <cmath> 20 #include <iterator> 21 #include <numeric> 22 23 #include "common/embedding-network.h" 24 #include "common/feature-extractor.h" 25 #include "common/memory_image/embedding-network-params-from-image.h" 26 #include "common/memory_image/memory-image-reader.h" 27 #include "common/mmap.h" 28 #include "common/softmax.h" 29 #include "smartselect/text-classification-model.pb.h" 30 #include "util/base/logging.h" 31 #include "util/utf8/unicodetext.h" 32 #include "unicode/uchar.h" 33 34 namespace libtextclassifier { 35 36 using nlp_core::EmbeddingNetwork; 37 using nlp_core::EmbeddingNetworkProto; 38 using nlp_core::FeatureVector; 39 using nlp_core::MemoryImageReader; 40 using nlp_core::MmapFile; 41 using nlp_core::MmapHandle; 42 using nlp_core::ScopedMmap; 43 44 namespace { 45 46 int CountDigits(const std::string& str, CodepointSpan selection_indices) { 47 int count = 0; 48 int i = 0; 49 const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false); 50 for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) { 51 if (i >= selection_indices.first && i < selection_indices.second && 52 u_isdigit(*it)) { 53 ++count; 54 } 55 } 56 return count; 57 } 58 59 } // namespace 60 61 CodepointSpan TextClassificationModel::StripPunctuation( 62 CodepointSpan selection, const std::string& context) const { 63 UnicodeText context_unicode = UTF8ToUnicodeText(context, /*do_copy=*/false); 64 int context_length = 65 std::distance(context_unicode.begin(), context_unicode.end()); 66 67 // Check that the indices are valid. 68 if (selection.first < 0 || selection.first > context_length || 69 selection.second < 0 || selection.second > context_length) { 70 return selection; 71 } 72 73 // Move the left border until we encounter a non-punctuation character. 74 UnicodeText::const_iterator it_from_begin = context_unicode.begin(); 75 std::advance(it_from_begin, selection.first); 76 for (; punctuation_to_strip_.find(*it_from_begin) != 77 punctuation_to_strip_.end(); 78 ++it_from_begin, ++selection.first) { 79 } 80 81 // Unless we are already at the end, move the right border until we encounter 82 // a non-punctuation character. 83 UnicodeText::const_iterator it_from_end = context_unicode.begin(); 84 std::advance(it_from_end, selection.second); 85 if (it_from_begin != it_from_end) { 86 --it_from_end; 87 for (; punctuation_to_strip_.find(*it_from_end) != 88 punctuation_to_strip_.end(); 89 --it_from_end, --selection.second) { 90 } 91 return selection; 92 } else { 93 // When the token is all punctuation. 94 return {0, 0}; 95 } 96 } 97 98 TextClassificationModel::TextClassificationModel(int fd) : mmap_(fd) { 99 initialized_ = LoadModels(mmap_.handle()); 100 if (!initialized_) { 101 TC_LOG(ERROR) << "Failed to load models"; 102 return; 103 } 104 105 selection_options_ = selection_params_->GetSelectionModelOptions(); 106 for (const int codepoint : selection_options_.punctuation_to_strip()) { 107 punctuation_to_strip_.insert(codepoint); 108 } 109 110 sharing_options_ = selection_params_->GetSharingModelOptions(); 111 } 112 113 namespace { 114 115 // Converts sparse features vector to nlp_core::FeatureVector. 116 void SparseFeaturesToFeatureVector( 117 const std::vector<int> sparse_features, 118 const nlp_core::NumericFeatureType& feature_type, 119 nlp_core::FeatureVector* result) { 120 for (int feature_id : sparse_features) { 121 const int64 feature_value = 122 nlp_core::FloatFeatureValue(feature_id, 1.0 / sparse_features.size()) 123 .discrete_value; 124 result->add(const_cast<nlp_core::NumericFeatureType*>(&feature_type), 125 feature_value); 126 } 127 } 128 129 // Returns a function that can be used for mapping sparse and dense features 130 // to a float feature vector. 131 // NOTE: The network object needs to be available at the time when the returned 132 // function object is used. 133 FeatureVectorFn CreateFeatureVectorFn(const EmbeddingNetwork& network, 134 int sparse_embedding_size) { 135 const nlp_core::NumericFeatureType feature_type("chargram_continuous", 0); 136 return [&network, sparse_embedding_size, feature_type]( 137 const std::vector<int>& sparse_features, 138 const std::vector<float>& dense_features, float* embedding) { 139 nlp_core::FeatureVector feature_vector; 140 SparseFeaturesToFeatureVector(sparse_features, feature_type, 141 &feature_vector); 142 143 if (network.GetEmbedding(feature_vector, 0, embedding)) { 144 for (int i = 0; i < dense_features.size(); i++) { 145 embedding[sparse_embedding_size + i] = dense_features[i]; 146 } 147 return true; 148 } else { 149 return false; 150 } 151 }; 152 } 153 154 void ParseMergedModel(const MmapHandle& mmap_handle, 155 const char** selection_model, int* selection_model_length, 156 const char** sharing_model, int* sharing_model_length) { 157 // Read the length of the selection model. 158 const char* model_data = reinterpret_cast<const char*>(mmap_handle.start()); 159 *selection_model_length = 160 LittleEndian::ToHost32(*reinterpret_cast<const uint32*>(model_data)); 161 model_data += sizeof(*selection_model_length); 162 *selection_model = model_data; 163 model_data += *selection_model_length; 164 165 *sharing_model_length = 166 LittleEndian::ToHost32(*reinterpret_cast<const uint32*>(model_data)); 167 model_data += sizeof(*sharing_model_length); 168 *sharing_model = model_data; 169 } 170 171 } // namespace 172 173 bool TextClassificationModel::LoadModels(const MmapHandle& mmap_handle) { 174 if (!mmap_handle.ok()) { 175 return false; 176 } 177 178 const char *selection_model, *sharing_model; 179 int selection_model_length, sharing_model_length; 180 ParseMergedModel(mmap_handle, &selection_model, &selection_model_length, 181 &sharing_model, &sharing_model_length); 182 183 selection_params_.reset( 184 ModelParamsBuilder(selection_model, selection_model_length, nullptr)); 185 if (!selection_params_.get()) { 186 return false; 187 } 188 selection_network_.reset(new EmbeddingNetwork(selection_params_.get())); 189 selection_feature_processor_.reset( 190 new FeatureProcessor(selection_params_->GetFeatureProcessorOptions())); 191 selection_feature_fn_ = CreateFeatureVectorFn( 192 *selection_network_, selection_network_->EmbeddingSize(0)); 193 194 sharing_params_.reset( 195 ModelParamsBuilder(sharing_model, sharing_model_length, 196 selection_params_->GetEmbeddingParams())); 197 if (!sharing_params_.get()) { 198 return false; 199 } 200 sharing_network_.reset(new EmbeddingNetwork(sharing_params_.get())); 201 sharing_feature_processor_.reset( 202 new FeatureProcessor(sharing_params_->GetFeatureProcessorOptions())); 203 sharing_feature_fn_ = CreateFeatureVectorFn( 204 *sharing_network_, sharing_network_->EmbeddingSize(0)); 205 206 return true; 207 } 208 209 bool ReadSelectionModelOptions(int fd, ModelOptions* model_options) { 210 ScopedMmap mmap = ScopedMmap(fd); 211 if (!mmap.handle().ok()) { 212 TC_LOG(ERROR) << "Can't mmap."; 213 return false; 214 } 215 216 const char *selection_model, *sharing_model; 217 int selection_model_length, sharing_model_length; 218 ParseMergedModel(mmap.handle(), &selection_model, &selection_model_length, 219 &sharing_model, &sharing_model_length); 220 221 MemoryImageReader<EmbeddingNetworkProto> reader(selection_model, 222 selection_model_length); 223 224 auto model_options_extension_id = model_options_in_embedding_network_proto; 225 if (reader.trimmed_proto().HasExtension(model_options_extension_id)) { 226 *model_options = 227 reader.trimmed_proto().GetExtension(model_options_extension_id); 228 return true; 229 } else { 230 return false; 231 } 232 } 233 234 EmbeddingNetwork::Vector TextClassificationModel::InferInternal( 235 const std::string& context, CodepointSpan span, 236 const FeatureProcessor& feature_processor, const EmbeddingNetwork& network, 237 const FeatureVectorFn& feature_vector_fn, 238 std::vector<CodepointSpan>* selection_label_spans) const { 239 std::vector<Token> tokens; 240 int click_pos; 241 std::unique_ptr<CachedFeatures> cached_features; 242 const int embedding_size = network.EmbeddingSize(0); 243 if (!feature_processor.ExtractFeatures( 244 context, span, /*relative_click_span=*/{0, 0}, 245 CreateFeatureVectorFn(network, embedding_size), 246 embedding_size + feature_processor.DenseFeaturesCount(), &tokens, 247 &click_pos, &cached_features)) { 248 TC_LOG(ERROR) << "Could not extract features."; 249 return {}; 250 } 251 252 VectorSpan<float> features; 253 VectorSpan<Token> output_tokens; 254 if (!cached_features->Get(click_pos, &features, &output_tokens)) { 255 TC_LOG(ERROR) << "Could not extract features."; 256 return {}; 257 } 258 259 if (selection_label_spans != nullptr) { 260 if (!feature_processor.SelectionLabelSpans(output_tokens, 261 selection_label_spans)) { 262 TC_LOG(ERROR) << "Could not get spans for selection labels."; 263 return {}; 264 } 265 } 266 267 std::vector<float> scores; 268 network.ComputeLogits(features, &scores); 269 return scores; 270 } 271 272 CodepointSpan TextClassificationModel::SuggestSelection( 273 const std::string& context, CodepointSpan click_indices) const { 274 if (!initialized_) { 275 TC_LOG(ERROR) << "Not initialized"; 276 return click_indices; 277 } 278 279 if (std::get<0>(click_indices) >= std::get<1>(click_indices)) { 280 TC_LOG(ERROR) << "Trying to run SuggestSelection with invalid indices:" 281 << std::get<0>(click_indices) << " " 282 << std::get<1>(click_indices); 283 return click_indices; 284 } 285 286 const UnicodeText context_unicode = 287 UTF8ToUnicodeText(context, /*do_copy=*/false); 288 const int context_length = 289 std::distance(context_unicode.begin(), context_unicode.end()); 290 if (std::get<0>(click_indices) >= context_length || 291 std::get<1>(click_indices) > context_length) { 292 return click_indices; 293 } 294 295 CodepointSpan result; 296 if (selection_options_.enforce_symmetry()) { 297 result = SuggestSelectionSymmetrical(context, click_indices); 298 } else { 299 float score; 300 std::tie(result, score) = SuggestSelectionInternal(context, click_indices); 301 } 302 303 if (selection_options_.strip_punctuation()) { 304 result = StripPunctuation(result, context); 305 } 306 307 return result; 308 } 309 310 namespace { 311 312 std::pair<CodepointSpan, float> BestSelectionSpan( 313 CodepointSpan original_click_indices, const std::vector<float>& scores, 314 const std::vector<CodepointSpan>& selection_label_spans) { 315 if (!scores.empty()) { 316 const int prediction = 317 std::max_element(scores.begin(), scores.end()) - scores.begin(); 318 std::pair<CodepointIndex, CodepointIndex> selection = 319 selection_label_spans[prediction]; 320 321 if (selection.first == kInvalidIndex || selection.second == kInvalidIndex) { 322 TC_LOG(ERROR) << "Invalid indices predicted, returning input: " 323 << prediction << " " << selection.first << " " 324 << selection.second; 325 return {original_click_indices, -1.0}; 326 } 327 328 return {{selection.first, selection.second}, scores[prediction]}; 329 } else { 330 TC_LOG(ERROR) << "Returning default selection: scores.size() = " 331 << scores.size(); 332 return {original_click_indices, -1.0}; 333 } 334 } 335 336 } // namespace 337 338 std::pair<CodepointSpan, float> 339 TextClassificationModel::SuggestSelectionInternal( 340 const std::string& context, CodepointSpan click_indices) const { 341 if (!initialized_) { 342 TC_LOG(ERROR) << "Not initialized"; 343 return {click_indices, -1.0}; 344 } 345 346 std::vector<CodepointSpan> selection_label_spans; 347 EmbeddingNetwork::Vector scores = InferInternal( 348 context, click_indices, *selection_feature_processor_, 349 *selection_network_, selection_feature_fn_, &selection_label_spans); 350 scores = nlp_core::ComputeSoftmax(scores); 351 352 return BestSelectionSpan(click_indices, scores, selection_label_spans); 353 } 354 355 // Implements a greedy-search-like algorithm for making selections symmetric. 356 // 357 // Steps: 358 // 1. Get a set of selection proposals from places around the clicked word. 359 // 2. For each proposal (going from highest-scoring), check if the tokens that 360 // the proposal selects are still free, in which case it claims them, if a 361 // proposal that contains the clicked token is found, it is returned as the 362 // suggestion. 363 // 364 // This algorithm should ensure that if a selection is proposed, it does not 365 // matter which word of it was tapped - all of them will lead to the same 366 // selection. 367 CodepointSpan TextClassificationModel::SuggestSelectionSymmetrical( 368 const std::string& context, CodepointSpan click_indices) const { 369 const int symmetry_context_size = selection_options_.symmetry_context_size(); 370 std::vector<Token> tokens; 371 std::unique_ptr<CachedFeatures> cached_features; 372 int click_index; 373 int embedding_size = selection_network_->EmbeddingSize(0); 374 if (!selection_feature_processor_->ExtractFeatures( 375 context, click_indices, /*relative_click_span=*/ 376 {symmetry_context_size, symmetry_context_size + 1}, 377 selection_feature_fn_, 378 embedding_size + selection_feature_processor_->DenseFeaturesCount(), 379 &tokens, &click_index, &cached_features)) { 380 TC_LOG(ERROR) << "Couldn't ExtractFeatures."; 381 return click_indices; 382 } 383 384 // Scan in the symmetry context for selection span proposals. 385 std::vector<std::pair<CodepointSpan, float>> proposals; 386 387 for (int i = -symmetry_context_size; i < symmetry_context_size + 1; ++i) { 388 const int token_index = click_index + i; 389 if (token_index >= 0 && token_index < tokens.size() && 390 !tokens[token_index].is_padding) { 391 float score; 392 VectorSpan<float> features; 393 VectorSpan<Token> output_tokens; 394 395 CodepointSpan span; 396 if (cached_features->Get(token_index, &features, &output_tokens)) { 397 std::vector<float> scores; 398 selection_network_->ComputeLogits(features, &scores); 399 400 std::vector<CodepointSpan> selection_label_spans; 401 if (selection_feature_processor_->SelectionLabelSpans( 402 output_tokens, &selection_label_spans)) { 403 scores = nlp_core::ComputeSoftmax(scores); 404 std::tie(span, score) = 405 BestSelectionSpan(click_indices, scores, selection_label_spans); 406 if (span.first != kInvalidIndex && span.second != kInvalidIndex && 407 score >= 0) { 408 proposals.push_back({span, score}); 409 } 410 } 411 } 412 } 413 } 414 415 // Sort selection span proposals by their respective probabilities. 416 std::sort( 417 proposals.begin(), proposals.end(), 418 [](std::pair<CodepointSpan, float> a, std::pair<CodepointSpan, float> b) { 419 return a.second > b.second; 420 }); 421 422 // Go from the highest-scoring proposal and claim tokens. Tokens are marked as 423 // claimed by the higher-scoring selection proposals, so that the 424 // lower-scoring ones cannot use them. Returns the selection proposal if it 425 // contains the clicked token. 426 std::vector<int> used_tokens(tokens.size(), 0); 427 for (auto span_result : proposals) { 428 TokenSpan span = CodepointSpanToTokenSpan(tokens, span_result.first); 429 if (span.first != kInvalidIndex && span.second != kInvalidIndex) { 430 bool feasible = true; 431 for (int i = span.first; i < span.second; i++) { 432 if (used_tokens[i] != 0) { 433 feasible = false; 434 break; 435 } 436 } 437 438 if (feasible) { 439 if (span.first <= click_index && span.second > click_index) { 440 return {span_result.first.first, span_result.first.second}; 441 } 442 for (int i = span.first; i < span.second; i++) { 443 used_tokens[i] = 1; 444 } 445 } 446 } 447 } 448 449 return {click_indices.first, click_indices.second}; 450 } 451 452 std::vector<std::pair<std::string, float>> 453 TextClassificationModel::ClassifyText(const std::string& context, 454 CodepointSpan selection_indices, 455 int hint_flags) const { 456 if (!initialized_) { 457 TC_LOG(ERROR) << "Not initialized"; 458 return {}; 459 } 460 461 if (std::get<0>(selection_indices) >= std::get<1>(selection_indices)) { 462 TC_LOG(ERROR) << "Trying to run ClassifyText with invalid indices: " 463 << std::get<0>(selection_indices) << " " 464 << std::get<1>(selection_indices); 465 return {}; 466 } 467 468 if (hint_flags & SELECTION_IS_URL && 469 sharing_options_.always_accept_url_hint()) { 470 return {{kUrlHintCollection, 1.0}}; 471 } 472 473 if (hint_flags & SELECTION_IS_EMAIL && 474 sharing_options_.always_accept_email_hint()) { 475 return {{kEmailHintCollection, 1.0}}; 476 } 477 478 EmbeddingNetwork::Vector scores = 479 InferInternal(context, selection_indices, *sharing_feature_processor_, 480 *sharing_network_, sharing_feature_fn_, nullptr); 481 if (scores.empty() || 482 scores.size() != sharing_feature_processor_->NumCollections()) { 483 TC_LOG(ERROR) << "Using default class: scores.size() = " << scores.size(); 484 return {}; 485 } 486 487 scores = nlp_core::ComputeSoftmax(scores); 488 489 std::vector<std::pair<std::string, float>> result; 490 for (int i = 0; i < scores.size(); i++) { 491 result.push_back( 492 {sharing_feature_processor_->LabelToCollection(i), scores[i]}); 493 } 494 std::sort(result.begin(), result.end(), 495 [](const std::pair<std::string, float>& a, 496 const std::pair<std::string, float>& b) { 497 return a.second > b.second; 498 }); 499 500 // Phone class sanity check. 501 if (result.begin()->first == kPhoneCollection) { 502 const int digit_count = CountDigits(context, selection_indices); 503 if (digit_count < sharing_options_.phone_min_num_digits() || 504 digit_count > sharing_options_.phone_max_num_digits()) { 505 return {{kOtherCollection, 1.0}}; 506 } 507 } 508 509 return result; 510 } 511 512 } // namespace libtextclassifier 513