/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "smartselect/text-classification-model.h"

#include <algorithm>
#include <cmath>
#include <iterator>
#include <numeric>
#include <tuple>
     22 
     23 #include "common/embedding-network.h"
     24 #include "common/feature-extractor.h"
     25 #include "common/memory_image/embedding-network-params-from-image.h"
     26 #include "common/memory_image/memory-image-reader.h"
     27 #include "common/mmap.h"
     28 #include "common/softmax.h"
     29 #include "smartselect/text-classification-model.pb.h"
     30 #include "util/base/logging.h"
     31 #include "util/utf8/unicodetext.h"
     32 #include "unicode/uchar.h"
     33 
namespace libtextclassifier {

using nlp_core::EmbeddingNetwork;
using nlp_core::EmbeddingNetworkProto;
using nlp_core::FeatureVector;
using nlp_core::MemoryImageReader;
using nlp_core::MmapFile;
using nlp_core::MmapHandle;
using nlp_core::ScopedMmap;

namespace {

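// Returns the number of digit codepoints inside the [first, second) codepoint
// range of str given by selection_indices.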
int CountDigits(const std::string& str, CodepointSpan selection_indices) {
  int count = 0;
  int i = 0;
  const UnicodeText unicode_str = UTF8ToUnicodeText(str, /*do_copy=*/false);
  for (auto it = unicode_str.begin(); it != unicode_str.end(); ++it, ++i) {
    if (i >= selection_indices.first && i < selection_indices.second &&
        u_isdigit(*it)) {
      ++count;
    }
  }
  return count;
}

}  // namespace

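// Strips punctuation codepoints (those in punctuation_to_strip_) from both
// ends of the selection. Returns {0, 0} when the selection consists only of
// punctuation, and the input selection when the indices are out of range.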
CodepointSpan TextClassificationModel::StripPunctuation(
    CodepointSpan selection, const std::string& context) const {
  UnicodeText context_unicode = UTF8ToUnicodeText(context, /*do_copy=*/false);
  int context_length =
      std::distance(context_unicode.begin(), context_unicode.end());

  // Check that the indices are valid.
  if (selection.first < 0 || selection.first > context_length ||
      selection.second < 0 || selection.second > context_length) {
    return selection;
  }

  // Move the left border until we encounter a non-punctuation character.
  UnicodeText::const_iterator it_from_begin = context_unicode.begin();
  std::advance(it_from_begin, selection.first);
  for (; punctuation_to_strip_.find(*it_from_begin) !=
         punctuation_to_strip_.end();
       ++it_from_begin, ++selection.first) {
  }

  // Unless the left border has already reached the right one, move the right
  // border left until we encounter a non-punctuation character.
  UnicodeText::const_iterator it_from_end = context_unicode.begin();
  std::advance(it_from_end, selection.second);
  if (it_from_begin != it_from_end) {
    --it_from_end;
    for (; punctuation_to_strip_.find(*it_from_end) !=
           punctuation_to_strip_.end();
         --it_from_end, --selection.second) {
    }
    return selection;
  } else {
    // When the token is all punctuation.
    return {0, 0};
  }
}

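// Memory-maps the file descriptor of the merged model and loads the selection
// and sharing models from it, together with their options.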
TextClassificationModel::TextClassificationModel(int fd) : mmap_(fd) {
  initialized_ = LoadModels(mmap_.handle());
  if (!initialized_) {
    TC_LOG(ERROR) << "Failed to load models";
    return;
  }

  selection_options_ = selection_params_->GetSelectionModelOptions();
  for (const int codepoint : selection_options_.punctuation_to_strip()) {
    punctuation_to_strip_.insert(codepoint);
  }

  sharing_options_ = selection_params_->GetSharingModelOptions();
}

namespace {

// Converts a sparse feature vector into an nlp_core::FeatureVector.
void SparseFeaturesToFeatureVector(
    const std::vector<int>& sparse_features,
    const nlp_core::NumericFeatureType& feature_type,
    nlp_core::FeatureVector* result) {
  for (int feature_id : sparse_features) {
    const int64 feature_value =
        nlp_core::FloatFeatureValue(feature_id, 1.0 / sparse_features.size())
            .discrete_value;
    result->add(const_cast<nlp_core::NumericFeatureType*>(&feature_type),
                feature_value);
  }
}

// Returns a function that can be used for mapping sparse and dense features
// to a float feature vector.
// NOTE: The network object needs to be available at the time when the returned
// function object is used.
FeatureVectorFn CreateFeatureVectorFn(const EmbeddingNetwork& network,
                                      int sparse_embedding_size) {
  const nlp_core::NumericFeatureType feature_type("chargram_continuous", 0);
  return [&network, sparse_embedding_size, feature_type](
             const std::vector<int>& sparse_features,
             const std::vector<float>& dense_features, float* embedding) {
    nlp_core::FeatureVector feature_vector;
    SparseFeaturesToFeatureVector(sparse_features, feature_type,
                                  &feature_vector);

    if (network.GetEmbedding(feature_vector, 0, embedding)) {
      for (int i = 0; i < dense_features.size(); i++) {
        embedding[sparse_embedding_size + i] = dense_features[i];
      }
      return true;
    } else {
      return false;
    }
  };
}

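// Splits a merged model image into its two parts: the selection model followed
// by the sharing model. Each part is prefixed by its length stored as a
// little-endian uint32.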
void ParseMergedModel(const MmapHandle& mmap_handle,
                      const char** selection_model, int* selection_model_length,
                      const char** sharing_model, int* sharing_model_length) {
  // Read the length of the selection model.
  const char* model_data = reinterpret_cast<const char*>(mmap_handle.start());
  *selection_model_length =
      LittleEndian::ToHost32(*reinterpret_cast<const uint32*>(model_data));
  model_data += sizeof(*selection_model_length);
  *selection_model = model_data;
  model_data += *selection_model_length;

  *sharing_model_length =
      LittleEndian::ToHost32(*reinterpret_cast<const uint32*>(model_data));
  model_data += sizeof(*sharing_model_length);
  *sharing_model = model_data;
}

}  // namespace

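// Parses the merged model image and initializes the selection and sharing
// embedding networks, feature processors and feature-vector functions.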
bool TextClassificationModel::LoadModels(const MmapHandle& mmap_handle) {
  if (!mmap_handle.ok()) {
    return false;
  }

  const char *selection_model, *sharing_model;
  int selection_model_length, sharing_model_length;
  ParseMergedModel(mmap_handle, &selection_model, &selection_model_length,
                   &sharing_model, &sharing_model_length);

  selection_params_.reset(
      ModelParamsBuilder(selection_model, selection_model_length, nullptr));
  if (!selection_params_.get()) {
    return false;
  }
  selection_network_.reset(new EmbeddingNetwork(selection_params_.get()));
  selection_feature_processor_.reset(
      new FeatureProcessor(selection_params_->GetFeatureProcessorOptions()));
  selection_feature_fn_ = CreateFeatureVectorFn(
      *selection_network_, selection_network_->EmbeddingSize(0));

  sharing_params_.reset(
      ModelParamsBuilder(sharing_model, sharing_model_length,
                         selection_params_->GetEmbeddingParams()));
  if (!sharing_params_.get()) {
    return false;
  }
  sharing_network_.reset(new EmbeddingNetwork(sharing_params_.get()));
  sharing_feature_processor_.reset(
      new FeatureProcessor(sharing_params_->GetFeatureProcessorOptions()));
  sharing_feature_fn_ = CreateFeatureVectorFn(
      *sharing_network_, sharing_network_->EmbeddingSize(0));

  return true;
}

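// Memory-maps the model given by fd and extracts the ModelOptions stored as a
// proto extension of the selection model, if present.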
bool ReadSelectionModelOptions(int fd, ModelOptions* model_options) {
  ScopedMmap mmap(fd);
  if (!mmap.handle().ok()) {
    TC_LOG(ERROR) << "Can't mmap.";
    return false;
  }

  const char *selection_model, *sharing_model;
  int selection_model_length, sharing_model_length;
  ParseMergedModel(mmap.handle(), &selection_model, &selection_model_length,
                   &sharing_model, &sharing_model_length);

  MemoryImageReader<EmbeddingNetworkProto> reader(selection_model,
                                                  selection_model_length);

  auto model_options_extension_id = model_options_in_embedding_network_proto;
  if (reader.trimmed_proto().HasExtension(model_options_extension_id)) {
    *model_options =
        reader.trimmed_proto().GetExtension(model_options_extension_id);
    return true;
  } else {
    return false;
  }
}

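// Extracts features around the given span and runs the network on them.
// Returns the raw logits, or an empty vector on failure. When
// selection_label_spans is non-null, it is filled with the codepoint spans
// that correspond to the selection labels.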
EmbeddingNetwork::Vector TextClassificationModel::InferInternal(
    const std::string& context, CodepointSpan span,
    const FeatureProcessor& feature_processor, const EmbeddingNetwork& network,
    const FeatureVectorFn& feature_vector_fn,
    std::vector<CodepointSpan>* selection_label_spans) const {
  std::vector<Token> tokens;
  int click_pos;
  std::unique_ptr<CachedFeatures> cached_features;
  const int embedding_size = network.EmbeddingSize(0);
  if (!feature_processor.ExtractFeatures(
          context, span, /*relative_click_span=*/{0, 0},
          CreateFeatureVectorFn(network, embedding_size),
          embedding_size + feature_processor.DenseFeaturesCount(), &tokens,
          &click_pos, &cached_features)) {
    TC_LOG(ERROR) << "Could not extract features.";
    return {};
  }

  VectorSpan<float> features;
  VectorSpan<Token> output_tokens;
  if (!cached_features->Get(click_pos, &features, &output_tokens)) {
    TC_LOG(ERROR) << "Could not extract features.";
    return {};
  }

  if (selection_label_spans != nullptr) {
    if (!feature_processor.SelectionLabelSpans(output_tokens,
                                               selection_label_spans)) {
      TC_LOG(ERROR) << "Could not get spans for selection labels.";
      return {};
    }
  }

  std::vector<float> scores;
  network.ComputeLogits(features, &scores);
  return scores;
}

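// Validates the click indices, then either runs the symmetrical selection
// algorithm or a single model inference, and optionally strips punctuation
// from the result.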
CodepointSpan TextClassificationModel::SuggestSelection(
    const std::string& context, CodepointSpan click_indices) const {
  if (!initialized_) {
    TC_LOG(ERROR) << "Not initialized";
    return click_indices;
  }

  if (std::get<0>(click_indices) >= std::get<1>(click_indices)) {
    TC_LOG(ERROR) << "Trying to run SuggestSelection with invalid indices: "
                  << std::get<0>(click_indices) << " "
                  << std::get<1>(click_indices);
    return click_indices;
  }

  const UnicodeText context_unicode =
      UTF8ToUnicodeText(context, /*do_copy=*/false);
  const int context_length =
      std::distance(context_unicode.begin(), context_unicode.end());
  if (std::get<0>(click_indices) >= context_length ||
      std::get<1>(click_indices) > context_length) {
    return click_indices;
  }

  CodepointSpan result;
  if (selection_options_.enforce_symmetry()) {
    result = SuggestSelectionSymmetrical(context, click_indices);
  } else {
    float score;
    std::tie(result, score) = SuggestSelectionInternal(context, click_indices);
  }

  if (selection_options_.strip_punctuation()) {
    result = StripPunctuation(result, context);
  }

  return result;
}

namespace {

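// Picks the selection label span with the highest score. Falls back to the
// original click indices (with score -1.0) when the scores are empty or the
// predicted span is invalid.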
std::pair<CodepointSpan, float> BestSelectionSpan(
    CodepointSpan original_click_indices, const std::vector<float>& scores,
    const std::vector<CodepointSpan>& selection_label_spans) {
  if (!scores.empty()) {
    const int prediction =
        std::max_element(scores.begin(), scores.end()) - scores.begin();
    std::pair<CodepointIndex, CodepointIndex> selection =
        selection_label_spans[prediction];

    if (selection.first == kInvalidIndex || selection.second == kInvalidIndex) {
      TC_LOG(ERROR) << "Invalid indices predicted, returning input: "
                    << prediction << " " << selection.first << " "
                    << selection.second;
      return {original_click_indices, -1.0};
    }

    return {{selection.first, selection.second}, scores[prediction]};
  } else {
    TC_LOG(ERROR) << "Returning default selection: scores.size() = "
                  << scores.size();
    return {original_click_indices, -1.0};
  }
}

}  // namespace

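// Runs the selection model once around click_indices and returns the
// best-scoring span together with its softmax probability.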
std::pair<CodepointSpan, float>
TextClassificationModel::SuggestSelectionInternal(
    const std::string& context, CodepointSpan click_indices) const {
  if (!initialized_) {
    TC_LOG(ERROR) << "Not initialized";
    return {click_indices, -1.0};
  }

  std::vector<CodepointSpan> selection_label_spans;
  EmbeddingNetwork::Vector scores = InferInternal(
      context, click_indices, *selection_feature_processor_,
      *selection_network_, selection_feature_fn_, &selection_label_spans);
  scores = nlp_core::ComputeSoftmax(scores);

  return BestSelectionSpan(click_indices, scores, selection_label_spans);
}

// Implements a greedy-search-like algorithm for making selections symmetric.
//
// Steps:
// 1. Get a set of selection proposals from places around the clicked word.
// 2. Going from the highest-scoring proposal, check whether the tokens that
//    the proposal selects are still free; if so, the proposal claims them.
//    If a proposal that contains the clicked token is found, it is returned
//    as the suggestion.
//
// This algorithm should ensure that if a selection is proposed, it does not
// matter which word of it was tapped - all of them will lead to the same
// selection.
CodepointSpan TextClassificationModel::SuggestSelectionSymmetrical(
    const std::string& context, CodepointSpan click_indices) const {
  const int symmetry_context_size = selection_options_.symmetry_context_size();
  std::vector<Token> tokens;
  std::unique_ptr<CachedFeatures> cached_features;
  int click_index;
  int embedding_size = selection_network_->EmbeddingSize(0);
  if (!selection_feature_processor_->ExtractFeatures(
          context, click_indices, /*relative_click_span=*/
          {symmetry_context_size, symmetry_context_size + 1},
          selection_feature_fn_,
          embedding_size + selection_feature_processor_->DenseFeaturesCount(),
          &tokens, &click_index, &cached_features)) {
    TC_LOG(ERROR) << "Couldn't ExtractFeatures.";
    return click_indices;
  }

  // Scan in the symmetry context for selection span proposals.
  std::vector<std::pair<CodepointSpan, float>> proposals;

  for (int i = -symmetry_context_size; i < symmetry_context_size + 1; ++i) {
    const int token_index = click_index + i;
    if (token_index >= 0 && token_index < tokens.size() &&
        !tokens[token_index].is_padding) {
      float score;
      VectorSpan<float> features;
      VectorSpan<Token> output_tokens;

      CodepointSpan span;
      if (cached_features->Get(token_index, &features, &output_tokens)) {
        std::vector<float> scores;
        selection_network_->ComputeLogits(features, &scores);

        std::vector<CodepointSpan> selection_label_spans;
        if (selection_feature_processor_->SelectionLabelSpans(
                output_tokens, &selection_label_spans)) {
          scores = nlp_core::ComputeSoftmax(scores);
          std::tie(span, score) =
              BestSelectionSpan(click_indices, scores, selection_label_spans);
          if (span.first != kInvalidIndex && span.second != kInvalidIndex &&
              score >= 0) {
            proposals.push_back({span, score});
          }
        }
      }
    }
  }

  // Sort selection span proposals by their respective probabilities.
  std::sort(
      proposals.begin(), proposals.end(),
      [](std::pair<CodepointSpan, float> a, std::pair<CodepointSpan, float> b) {
        return a.second > b.second;
      });

  // Go through the proposals from the highest score down and claim their
  // tokens. Tokens claimed by a higher-scoring selection proposal cannot be
  // used by the lower-scoring ones. Return the first feasible proposal that
  // contains the clicked token.
  std::vector<int> used_tokens(tokens.size(), 0);
  for (auto span_result : proposals) {
    TokenSpan span = CodepointSpanToTokenSpan(tokens, span_result.first);
    if (span.first != kInvalidIndex && span.second != kInvalidIndex) {
      bool feasible = true;
      for (int i = span.first; i < span.second; i++) {
        if (used_tokens[i] != 0) {
          feasible = false;
          break;
        }
      }

      if (feasible) {
        if (span.first <= click_index && span.second > click_index) {
          return {span_result.first.first, span_result.first.second};
        }
        for (int i = span.first; i < span.second; i++) {
          used_tokens[i] = 1;
        }
      }
    }
  }

  return {click_indices.first, click_indices.second};
}

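// Classifies the selected text into a collection. Honors the URL and email
// hints when the model options allow it, and remaps a phone classification
// to the "other" collection when the digit count is out of the configured
// range.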
std::vector<std::pair<std::string, float>>
TextClassificationModel::ClassifyText(const std::string& context,
                                      CodepointSpan selection_indices,
                                      int hint_flags) const {
  if (!initialized_) {
    TC_LOG(ERROR) << "Not initialized";
    return {};
  }

  if (std::get<0>(selection_indices) >= std::get<1>(selection_indices)) {
    TC_LOG(ERROR) << "Trying to run ClassifyText with invalid indices: "
                  << std::get<0>(selection_indices) << " "
                  << std::get<1>(selection_indices);
    return {};
  }

  if (hint_flags & SELECTION_IS_URL &&
      sharing_options_.always_accept_url_hint()) {
    return {{kUrlHintCollection, 1.0}};
  }

  if (hint_flags & SELECTION_IS_EMAIL &&
      sharing_options_.always_accept_email_hint()) {
    return {{kEmailHintCollection, 1.0}};
  }

  EmbeddingNetwork::Vector scores =
      InferInternal(context, selection_indices, *sharing_feature_processor_,
                    *sharing_network_, sharing_feature_fn_, nullptr);
  if (scores.empty() ||
      scores.size() != sharing_feature_processor_->NumCollections()) {
    TC_LOG(ERROR) << "Using default class: scores.size() = " << scores.size();
    return {};
  }

  scores = nlp_core::ComputeSoftmax(scores);

  std::vector<std::pair<std::string, float>> result;
  for (int i = 0; i < scores.size(); i++) {
    result.push_back(
        {sharing_feature_processor_->LabelToCollection(i), scores[i]});
  }
  std::sort(result.begin(), result.end(),
            [](const std::pair<std::string, float>& a,
               const std::pair<std::string, float>& b) {
              return a.second > b.second;
            });

  // Phone class sanity check.
  if (result.begin()->first == kPhoneCollection) {
    const int digit_count = CountDigits(context, selection_indices);
    if (digit_count < sharing_options_.phone_min_num_digits() ||
        digit_count > sharing_options_.phone_max_num_digits()) {
      return {{kOtherCollection, 1.0}};
    }
  }

  return result;
}

}  // namespace libtextclassifier