1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "lang_id/lang-id.h" 18 19 #include <stdio.h> 20 21 #include <algorithm> 22 #include <memory> 23 #include <string> 24 #include <unordered_map> 25 #include <vector> 26 27 #include "lang_id/common/embedding-feature-interface.h" 28 #include "lang_id/common/embedding-network-params.h" 29 #include "lang_id/common/embedding-network.h" 30 #include "lang_id/common/fel/feature-extractor.h" 31 #include "lang_id/common/lite_base/logging.h" 32 #include "lang_id/common/lite_strings/numbers.h" 33 #include "lang_id/common/lite_strings/str-split.h" 34 #include "lang_id/common/lite_strings/stringpiece.h" 35 #include "lang_id/common/math/algorithm.h" 36 #include "lang_id/common/math/softmax.h" 37 #include "lang_id/custom-tokenizer.h" 38 #include "lang_id/features/light-sentence-features.h" 39 #include "lang_id/light-sentence.h" 40 41 namespace libtextclassifier3 { 42 namespace mobile { 43 namespace lang_id { 44 45 namespace { 46 // Default value for the confidence threshold. If the confidence of the top 47 // prediction is below this threshold, then FindLanguage() returns 48 // LangId::kUnknownLanguageCode. Note: this is just a default value; if the 49 // TaskSpec from the model specifies a "reliability_thresh" parameter, then we 50 // use that value instead. Note: for legacy reasons, our code and comments use 51 // the terms "confidence", "probability" and "reliability" equivalently. 52 static const float kDefaultConfidenceThreshold = 0.50f; 53 } // namespace 54 55 // Class that performs all work behind LangId. 56 class LangIdImpl { 57 public: 58 explicit LangIdImpl(std::unique_ptr<ModelProvider> model_provider) 59 : model_provider_(std::move(model_provider)), 60 lang_id_brain_interface_("language_identifier") { 61 // Note: in the code below, we set valid_ to true only if all initialization 62 // steps completed successfully. Otherwise, we return early, leaving valid_ 63 // to its default value false. 64 if (!model_provider_ || !model_provider_->is_valid()) { 65 SAFTM_LOG(ERROR) << "Invalid model provider"; 66 return; 67 } 68 69 auto *nn_params = model_provider_->GetNnParams(); 70 if (!nn_params) { 71 SAFTM_LOG(ERROR) << "No NN params"; 72 return; 73 } 74 network_.reset(new EmbeddingNetwork(nn_params)); 75 76 languages_ = model_provider_->GetLanguages(); 77 if (languages_.empty()) { 78 SAFTM_LOG(ERROR) << "No known languages"; 79 return; 80 } 81 82 TaskContext context = *model_provider_->GetTaskContext(); 83 if (!Setup(&context)) { 84 SAFTM_LOG(ERROR) << "Unable to Setup() LangId"; 85 return; 86 } 87 if (!Init(&context)) { 88 SAFTM_LOG(ERROR) << "Unable to Init() LangId"; 89 return; 90 } 91 valid_ = true; 92 } 93 94 string FindLanguage(StringPiece text) const { 95 // NOTE: it would be wasteful to implement this method in terms of 96 // FindLanguages(). We just need the most likely language and its 97 // probability; no need to compute (and allocate) a vector of pairs for all 98 // languages, nor to compute probabilities for all non-top languages. 99 if (!is_valid()) { 100 return LangId::kUnknownLanguageCode; 101 } 102 103 std::vector<float> scores; 104 ComputeScores(text, &scores); 105 106 int prediction_id = GetArgMax(scores); 107 const string language = GetLanguageForSoftmaxLabel(prediction_id); 108 float probability = ComputeSoftmaxProbability(scores, prediction_id); 109 SAFTM_DLOG(INFO) << "Predicted " << language 110 << " with prob: " << probability << " for \"" << text 111 << "\""; 112 113 // Find confidence threshold for language. 114 float threshold = default_threshold_; 115 auto it = per_lang_thresholds_.find(language); 116 if (it != per_lang_thresholds_.end()) { 117 threshold = it->second; 118 } 119 if (probability < threshold) { 120 SAFTM_DLOG(INFO) << " below threshold => " 121 << LangId::kUnknownLanguageCode; 122 return LangId::kUnknownLanguageCode; 123 } 124 return language; 125 } 126 127 void FindLanguages(StringPiece text, LangIdResult *result) const { 128 if (result == nullptr) return; 129 130 result->predictions.clear(); 131 if (!is_valid()) { 132 result->predictions.emplace_back(LangId::kUnknownLanguageCode, 1); 133 return; 134 } 135 136 std::vector<float> scores; 137 ComputeScores(text, &scores); 138 139 // Compute and sort softmax in descending order by probability and convert 140 // IDs to language code strings. When probabilities are equal, we sort by 141 // language code string in ascending order. 142 std::vector<float> softmax = ComputeSoftmax(scores); 143 144 for (int i = 0; i < softmax.size(); ++i) { 145 result->predictions.emplace_back(GetLanguageForSoftmaxLabel(i), 146 softmax[i]); 147 } 148 149 // Sort the resulting language predictions by probability in descending 150 // order. 151 std::sort(result->predictions.begin(), result->predictions.end(), 152 [](const std::pair<string, float> &a, 153 const std::pair<string, float> &b) { 154 if (a.second == b.second) { 155 return a.first.compare(b.first) < 0; 156 } else { 157 return a.second > b.second; 158 } 159 }); 160 } 161 162 bool is_valid() const { return valid_; } 163 164 int GetModelVersion() const { return model_version_; } 165 166 // Returns a property stored in the model file. 167 template <typename T, typename R> 168 R GetProperty(const string &property, T default_value) const { 169 return model_provider_->GetTaskContext()->Get(property, default_value); 170 } 171 172 private: 173 bool Setup(TaskContext *context) { 174 tokenizer_.Setup(context); 175 if (!lang_id_brain_interface_.SetupForProcessing(context)) return false; 176 default_threshold_ = 177 context->Get("reliability_thresh", kDefaultConfidenceThreshold); 178 179 // Parse task parameter "per_lang_reliability_thresholds", fill 180 // per_lang_thresholds_. 181 const string thresholds_str = 182 context->Get("per_lang_reliability_thresholds", ""); 183 std::vector<StringPiece> tokens = LiteStrSplit(thresholds_str, ','); 184 for (const auto &token : tokens) { 185 if (token.empty()) continue; 186 std::vector<StringPiece> parts = LiteStrSplit(token, '='); 187 float threshold = 0.0f; 188 if ((parts.size() == 2) && LiteAtof(parts[1], &threshold)) { 189 per_lang_thresholds_[string(parts[0])] = threshold; 190 } else { 191 SAFTM_LOG(ERROR) << "Broken token: \"" << token << "\""; 192 } 193 } 194 model_version_ = context->Get("model_version", model_version_); 195 return true; 196 } 197 198 bool Init(TaskContext *context) { 199 return lang_id_brain_interface_.InitForProcessing(context); 200 } 201 202 // Extracts features for |text|, runs them through the feed-forward neural 203 // network, and computes the output scores (activations from the last layer). 204 // These scores can be used to compute the softmax probabilities for our 205 // labels (in this case, the languages). 206 void ComputeScores(StringPiece text, std::vector<float> *scores) const { 207 // Create a Sentence storing the input text. 208 LightSentence sentence; 209 tokenizer_.Tokenize(text, &sentence); 210 211 std::vector<FeatureVector> features = 212 lang_id_brain_interface_.GetFeaturesNoCaching(&sentence); 213 214 // Run feed-forward neural network to compute scores. 215 network_->ComputeFinalScores(features, scores); 216 } 217 218 // Returns language code for a softmax label. See comments for languages_ 219 // field. If label is out of range, returns LangId::kUnknownLanguageCode. 220 string GetLanguageForSoftmaxLabel(int label) const { 221 if ((label >= 0) && (label < languages_.size())) { 222 return languages_[label]; 223 } else { 224 SAFTM_LOG(ERROR) << "Softmax label " << label << " outside range [0, " 225 << languages_.size() << ")"; 226 return LangId::kUnknownLanguageCode; 227 } 228 } 229 230 std::unique_ptr<ModelProvider> model_provider_; 231 232 TokenizerForLangId tokenizer_; 233 234 EmbeddingFeatureInterface<LightSentenceExtractor, LightSentence> 235 lang_id_brain_interface_; 236 237 // Neural network to use for scoring. 238 std::unique_ptr<EmbeddingNetwork> network_; 239 240 // True if this object is ready to perform language predictions. 241 bool valid_ = false; 242 243 // Only predictions with a probability (confidence) above this threshold are 244 // reported. Otherwise, we report LangId::kUnknownLanguageCode. 245 float default_threshold_ = kDefaultConfidenceThreshold; 246 247 std::unordered_map<string, float> per_lang_thresholds_; 248 249 // Recognized languages: softmax label i means languages_[i] (something like 250 // "en", "fr", "ru", etc). 251 std::vector<string> languages_; 252 253 // Version of the model used by this LangIdImpl object. Zero means that the 254 // model version could not be determined. 255 int model_version_ = 0; 256 }; 257 258 const char LangId::kUnknownLanguageCode[] = "und"; 259 260 LangId::LangId(std::unique_ptr<ModelProvider> model_provider) 261 : pimpl_(new LangIdImpl(std::move(model_provider))) {} 262 263 LangId::~LangId() = default; 264 265 string LangId::FindLanguage(const char *data, size_t num_bytes) const { 266 StringPiece text(data, num_bytes); 267 return pimpl_->FindLanguage(text); 268 } 269 270 void LangId::FindLanguages(const char *data, size_t num_bytes, 271 LangIdResult *result) const { 272 SAFTM_DCHECK(result) << "LangIdResult must not be null."; 273 StringPiece text(data, num_bytes); 274 pimpl_->FindLanguages(text, result); 275 } 276 277 bool LangId::is_valid() const { return pimpl_->is_valid(); } 278 279 int LangId::GetModelVersion() const { return pimpl_->GetModelVersion(); } 280 281 float LangId::GetFloatProperty(const string &property, 282 float default_value) const { 283 return pimpl_->GetProperty<float, float>(property, default_value); 284 } 285 286 } // namespace lang_id 287 } // namespace mobile 288 } // namespace nlp_saft 289