// Home | History | Annotate | Download | only in lang_id
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "lang_id/lang-id.h"
     18 
     19 #include <stdio.h>
     20 
     21 #include <algorithm>
     22 #include <memory>
     23 #include <string>
     24 #include <unordered_map>
     25 #include <vector>
     26 
     27 #include "lang_id/common/embedding-feature-interface.h"
     28 #include "lang_id/common/embedding-network-params.h"
     29 #include "lang_id/common/embedding-network.h"
     30 #include "lang_id/common/fel/feature-extractor.h"
     31 #include "lang_id/common/lite_base/logging.h"
     32 #include "lang_id/common/lite_strings/numbers.h"
     33 #include "lang_id/common/lite_strings/str-split.h"
     34 #include "lang_id/common/lite_strings/stringpiece.h"
     35 #include "lang_id/common/math/algorithm.h"
     36 #include "lang_id/common/math/softmax.h"
     37 #include "lang_id/custom-tokenizer.h"
     38 #include "lang_id/features/light-sentence-features.h"
     39 #include "lang_id/light-sentence.h"
     40 
     41 namespace libtextclassifier3 {
     42 namespace mobile {
     43 namespace lang_id {
     44 
     45 namespace {
     46 // Default value for the confidence threshold.  If the confidence of the top
     47 // prediction is below this threshold, then FindLanguage() returns
     48 // LangId::kUnknownLanguageCode.  Note: this is just a default value; if the
     49 // TaskSpec from the model specifies a "reliability_thresh" parameter, then we
     50 // use that value instead.  Note: for legacy reasons, our code and comments use
     51 // the terms "confidence", "probability" and "reliability" equivalently.
     52 static const float kDefaultConfidenceThreshold = 0.50f;
     53 }  // namespace
     54 
     55 // Class that performs all work behind LangId.
     56 class LangIdImpl {
     57  public:
     58   explicit LangIdImpl(std::unique_ptr<ModelProvider> model_provider)
     59       : model_provider_(std::move(model_provider)),
     60         lang_id_brain_interface_("language_identifier") {
     61     // Note: in the code below, we set valid_ to true only if all initialization
     62     // steps completed successfully.  Otherwise, we return early, leaving valid_
     63     // to its default value false.
     64     if (!model_provider_ || !model_provider_->is_valid()) {
     65       SAFTM_LOG(ERROR) << "Invalid model provider";
     66       return;
     67     }
     68 
     69     auto *nn_params = model_provider_->GetNnParams();
     70     if (!nn_params) {
     71       SAFTM_LOG(ERROR) << "No NN params";
     72       return;
     73     }
     74     network_.reset(new EmbeddingNetwork(nn_params));
     75 
     76     languages_ = model_provider_->GetLanguages();
     77     if (languages_.empty()) {
     78       SAFTM_LOG(ERROR) << "No known languages";
     79       return;
     80     }
     81 
     82     TaskContext context = *model_provider_->GetTaskContext();
     83     if (!Setup(&context)) {
     84       SAFTM_LOG(ERROR) << "Unable to Setup() LangId";
     85       return;
     86     }
     87     if (!Init(&context)) {
     88       SAFTM_LOG(ERROR) << "Unable to Init() LangId";
     89       return;
     90     }
     91     valid_ = true;
     92   }
     93 
     94   string FindLanguage(StringPiece text) const {
     95     // NOTE: it would be wasteful to implement this method in terms of
     96     // FindLanguages().  We just need the most likely language and its
     97     // probability; no need to compute (and allocate) a vector of pairs for all
     98     // languages, nor to compute probabilities for all non-top languages.
     99     if (!is_valid()) {
    100       return LangId::kUnknownLanguageCode;
    101     }
    102 
    103     std::vector<float> scores;
    104     ComputeScores(text, &scores);
    105 
    106     int prediction_id = GetArgMax(scores);
    107     const string language = GetLanguageForSoftmaxLabel(prediction_id);
    108     float probability = ComputeSoftmaxProbability(scores, prediction_id);
    109     SAFTM_DLOG(INFO) << "Predicted " << language
    110                      << " with prob: " << probability << " for \"" << text
    111                      << "\"";
    112 
    113     // Find confidence threshold for language.
    114     float threshold = default_threshold_;
    115     auto it = per_lang_thresholds_.find(language);
    116     if (it != per_lang_thresholds_.end()) {
    117       threshold = it->second;
    118     }
    119     if (probability < threshold) {
    120       SAFTM_DLOG(INFO) << "  below threshold => "
    121                        << LangId::kUnknownLanguageCode;
    122       return LangId::kUnknownLanguageCode;
    123     }
    124     return language;
    125   }
    126 
    127   void FindLanguages(StringPiece text, LangIdResult *result) const {
    128     if (result == nullptr) return;
    129 
    130     result->predictions.clear();
    131     if (!is_valid()) {
    132       result->predictions.emplace_back(LangId::kUnknownLanguageCode, 1);
    133       return;
    134     }
    135 
    136     std::vector<float> scores;
    137     ComputeScores(text, &scores);
    138 
    139     // Compute and sort softmax in descending order by probability and convert
    140     // IDs to language code strings.  When probabilities are equal, we sort by
    141     // language code string in ascending order.
    142     std::vector<float> softmax = ComputeSoftmax(scores);
    143 
    144     for (int i = 0; i < softmax.size(); ++i) {
    145       result->predictions.emplace_back(GetLanguageForSoftmaxLabel(i),
    146                                        softmax[i]);
    147     }
    148 
    149     // Sort the resulting language predictions by probability in descending
    150     // order.
    151     std::sort(result->predictions.begin(), result->predictions.end(),
    152               [](const std::pair<string, float> &a,
    153                  const std::pair<string, float> &b) {
    154                 if (a.second == b.second) {
    155                   return a.first.compare(b.first) < 0;
    156                 } else {
    157                   return a.second > b.second;
    158                 }
    159               });
    160   }
    161 
    162   bool is_valid() const { return valid_; }
    163 
    164   int GetModelVersion() const { return model_version_; }
    165 
    166   // Returns a property stored in the model file.
    167   template <typename T, typename R>
    168   R GetProperty(const string &property, T default_value) const {
    169     return model_provider_->GetTaskContext()->Get(property, default_value);
    170   }
    171 
    172  private:
    173   bool Setup(TaskContext *context) {
    174     tokenizer_.Setup(context);
    175     if (!lang_id_brain_interface_.SetupForProcessing(context)) return false;
    176     default_threshold_ =
    177         context->Get("reliability_thresh", kDefaultConfidenceThreshold);
    178 
    179     // Parse task parameter "per_lang_reliability_thresholds", fill
    180     // per_lang_thresholds_.
    181     const string thresholds_str =
    182         context->Get("per_lang_reliability_thresholds", "");
    183     std::vector<StringPiece> tokens = LiteStrSplit(thresholds_str, ',');
    184     for (const auto &token : tokens) {
    185       if (token.empty()) continue;
    186       std::vector<StringPiece> parts = LiteStrSplit(token, '=');
    187       float threshold = 0.0f;
    188       if ((parts.size() == 2) && LiteAtof(parts[1], &threshold)) {
    189         per_lang_thresholds_[string(parts[0])] = threshold;
    190       } else {
    191         SAFTM_LOG(ERROR) << "Broken token: \"" << token << "\"";
    192       }
    193     }
    194     model_version_ = context->Get("model_version", model_version_);
    195     return true;
    196   }
    197 
    198   bool Init(TaskContext *context) {
    199     return lang_id_brain_interface_.InitForProcessing(context);
    200   }
    201 
    202   // Extracts features for |text|, runs them through the feed-forward neural
    203   // network, and computes the output scores (activations from the last layer).
    204   // These scores can be used to compute the softmax probabilities for our
    205   // labels (in this case, the languages).
    206   void ComputeScores(StringPiece text, std::vector<float> *scores) const {
    207     // Create a Sentence storing the input text.
    208     LightSentence sentence;
    209     tokenizer_.Tokenize(text, &sentence);
    210 
    211     std::vector<FeatureVector> features =
    212         lang_id_brain_interface_.GetFeaturesNoCaching(&sentence);
    213 
    214     // Run feed-forward neural network to compute scores.
    215     network_->ComputeFinalScores(features, scores);
    216   }
    217 
    218   // Returns language code for a softmax label.  See comments for languages_
    219   // field.  If label is out of range, returns LangId::kUnknownLanguageCode.
    220   string GetLanguageForSoftmaxLabel(int label) const {
    221     if ((label >= 0) && (label < languages_.size())) {
    222       return languages_[label];
    223     } else {
    224       SAFTM_LOG(ERROR) << "Softmax label " << label << " outside range [0, "
    225                        << languages_.size() << ")";
    226       return LangId::kUnknownLanguageCode;
    227     }
    228   }
    229 
    230   std::unique_ptr<ModelProvider> model_provider_;
    231 
    232   TokenizerForLangId tokenizer_;
    233 
    234   EmbeddingFeatureInterface<LightSentenceExtractor, LightSentence>
    235       lang_id_brain_interface_;
    236 
    237   // Neural network to use for scoring.
    238   std::unique_ptr<EmbeddingNetwork> network_;
    239 
    240   // True if this object is ready to perform language predictions.
    241   bool valid_ = false;
    242 
    243   // Only predictions with a probability (confidence) above this threshold are
    244   // reported.  Otherwise, we report LangId::kUnknownLanguageCode.
    245   float default_threshold_ = kDefaultConfidenceThreshold;
    246 
    247   std::unordered_map<string, float> per_lang_thresholds_;
    248 
    249   // Recognized languages: softmax label i means languages_[i] (something like
    250   // "en", "fr", "ru", etc).
    251   std::vector<string> languages_;
    252 
    253   // Version of the model used by this LangIdImpl object.  Zero means that the
    254   // model version could not be determined.
    255   int model_version_ = 0;
    256 };
    257 
// Language code reported when no language can be determined with enough
// confidence; "und" is the ISO 639 code for "undetermined".
const char LangId::kUnknownLanguageCode[] = "und";
    259 
// Takes ownership of |model_provider| and builds the private implementation.
// If initialization fails, the impl is left invalid (see is_valid()) rather
// than crashing.
LangId::LangId(std::unique_ptr<ModelProvider> model_provider)
    : pimpl_(new LangIdImpl(std::move(model_provider))) {}
    262 
// Defaulted out-of-line: LangIdImpl is a complete type at this point, which
// the pimpl idiom needs if pimpl_ is a smart pointer — NOTE(review): confirm
// pimpl_'s declared type in the header.
LangId::~LangId() = default;
    264 
    265 string LangId::FindLanguage(const char *data, size_t num_bytes) const {
    266   StringPiece text(data, num_bytes);
    267   return pimpl_->FindLanguage(text);
    268 }
    269 
    270 void LangId::FindLanguages(const char *data, size_t num_bytes,
    271                            LangIdResult *result) const {
    272   SAFTM_DCHECK(result) << "LangIdResult must not be null.";
    273   StringPiece text(data, num_bytes);
    274   pimpl_->FindLanguages(text, result);
    275 }
    276 
    277 bool LangId::is_valid() const { return pimpl_->is_valid(); }
    278 
    279 int LangId::GetModelVersion() const { return pimpl_->GetModelVersion(); }
    280 
    281 float LangId::GetFloatProperty(const string &property,
    282                                float default_value) const {
    283   return pimpl_->GetProperty<float, float>(property, default_value);
    284 }
    285 
    286 }  // namespace lang_id
    287 }  // namespace mobile
}  // namespace libtextclassifier3
    289