1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "lang_id/lang-id.h" 18 19 #include <stdio.h> 20 21 #include <algorithm> 22 #include <limits> 23 #include <memory> 24 #include <string> 25 #include <vector> 26 27 #include "common/algorithm.h" 28 #include "common/embedding-network-params-from-proto.h" 29 #include "common/embedding-network.pb.h" 30 #include "common/embedding-network.h" 31 #include "common/feature-extractor.h" 32 #include "common/file-utils.h" 33 #include "common/list-of-strings.pb.h" 34 #include "common/memory_image/in-memory-model-data.h" 35 #include "common/mmap.h" 36 #include "common/softmax.h" 37 #include "common/task-context.h" 38 #include "lang_id/custom-tokenizer.h" 39 #include "lang_id/lang-id-brain-interface.h" 40 #include "lang_id/language-identifier-features.h" 41 #include "lang_id/light-sentence-features.h" 42 #include "lang_id/light-sentence.h" 43 #include "lang_id/relevant-script-feature.h" 44 #include "util/base/logging.h" 45 #include "util/base/macros.h" 46 47 using ::libtextclassifier::nlp_core::file_utils::ParseProtoFromMemory; 48 49 namespace libtextclassifier { 50 namespace nlp_core { 51 namespace lang_id { 52 53 namespace { 54 // Default value for the probability threshold; see comments for 55 // LangId::SetProbabilityThreshold(). 56 static const float kDefaultProbabilityThreshold = 0.50; 57 58 // Default value for min text size below which our model can't provide a 59 // meaningful prediction. 60 static const int kDefaultMinTextSizeInBytes = 20; 61 62 // Initial value for the default language for LangId::FindLanguage(). The 63 // default language can be changed (for an individual LangId object) using 64 // LangId::SetDefaultLanguage(). 65 static const char kInitialDefaultLanguage[] = ""; 66 67 // Returns total number of bytes of the words from sentence, without the ^ 68 // (start-of-word) and $ (end-of-word) markers. Note: "real text" means that 69 // this ignores whitespace and punctuation characters from the original text. 70 int GetRealTextSize(const LightSentence &sentence) { 71 int total = 0; 72 for (int i = 0; i < sentence.num_words(); ++i) { 73 TC_DCHECK(!sentence.word(i).empty()); 74 TC_DCHECK_EQ('^', sentence.word(i).front()); 75 TC_DCHECK_EQ('$', sentence.word(i).back()); 76 total += sentence.word(i).size() - 2; 77 } 78 return total; 79 } 80 81 } // namespace 82 83 // Class that performs all work behind LangId. 84 class LangIdImpl { 85 public: 86 explicit LangIdImpl(const std::string &filename) { 87 // Using mmap as a fast way to read the model bytes. 88 ScopedMmap scoped_mmap(filename); 89 MmapHandle mmap_handle = scoped_mmap.handle(); 90 if (!mmap_handle.ok()) { 91 TC_LOG(ERROR) << "Unable to read model bytes."; 92 return; 93 } 94 95 Initialize(mmap_handle.to_stringpiece()); 96 } 97 98 explicit LangIdImpl(int fd) { 99 // Using mmap as a fast way to read the model bytes. 100 ScopedMmap scoped_mmap(fd); 101 MmapHandle mmap_handle = scoped_mmap.handle(); 102 if (!mmap_handle.ok()) { 103 TC_LOG(ERROR) << "Unable to read model bytes."; 104 return; 105 } 106 107 Initialize(mmap_handle.to_stringpiece()); 108 } 109 110 LangIdImpl(const char *ptr, size_t length) { 111 Initialize(StringPiece(ptr, length)); 112 } 113 114 void Initialize(StringPiece model_bytes) { 115 // Will set valid_ to true only on successful initialization. 116 valid_ = false; 117 118 // Make sure all relevant features are registered: 119 ContinuousBagOfNgramsFunction::RegisterClass(); 120 RelevantScriptFeature::RegisterClass(); 121 122 // NOTE(salcianu): code below relies on the fact that the current features 123 // do not rely on data from a TaskInput. Otherwise, one would have to use 124 // the more complex model registration mechanism, which requires more code. 125 InMemoryModelData model_data(model_bytes); 126 TaskContext context; 127 if (!model_data.GetTaskSpec(context.mutable_spec())) { 128 TC_LOG(ERROR) << "Unable to get model TaskSpec"; 129 return; 130 } 131 132 if (!ParseNetworkParams(model_data, &context)) { 133 return; 134 } 135 if (!ParseListOfKnownLanguages(model_data, &context)) { 136 return; 137 } 138 139 network_.reset(new EmbeddingNetwork(network_params_.get())); 140 if (!network_->is_valid()) { 141 return; 142 } 143 144 probability_threshold_ = 145 context.Get("reliability_thresh", kDefaultProbabilityThreshold); 146 min_text_size_in_bytes_ = 147 context.Get("min_text_size_in_bytes", kDefaultMinTextSizeInBytes); 148 version_ = context.Get("version", 0); 149 150 if (!lang_id_brain_interface_.Init(&context)) { 151 return; 152 } 153 valid_ = true; 154 } 155 156 void SetProbabilityThreshold(float threshold) { 157 probability_threshold_ = threshold; 158 } 159 160 void SetDefaultLanguage(const std::string &lang) { default_language_ = lang; } 161 162 std::string FindLanguage(const std::string &text) const { 163 std::vector<float> scores = ScoreLanguages(text); 164 if (scores.empty()) { 165 return default_language_; 166 } 167 168 // Softmax label with max score. 169 int label = GetArgMax(scores); 170 float probability = scores[label]; 171 if (probability < probability_threshold_) { 172 return default_language_; 173 } 174 return GetLanguageForSoftmaxLabel(label); 175 } 176 177 std::vector<std::pair<std::string, float>> FindLanguages( 178 const std::string &text) const { 179 std::vector<float> scores = ScoreLanguages(text); 180 181 std::vector<std::pair<std::string, float>> result; 182 for (int i = 0; i < scores.size(); i++) { 183 result.push_back({GetLanguageForSoftmaxLabel(i), scores[i]}); 184 } 185 186 // To avoid crashing clients that always expect at least one predicted 187 // language, we promised (see doc for this method) that the result always 188 // contains at least one element. 189 if (result.empty()) { 190 // We use a tiny probability, such that any client that uses a meaningful 191 // probability threshold ignores this prediction. We don't use 0.0f, to 192 // avoid crashing clients that normalize the probabilities we return here. 193 result.push_back({default_language_, 0.001f}); 194 } 195 return result; 196 } 197 198 std::vector<float> ScoreLanguages(const std::string &text) const { 199 if (!is_valid()) { 200 return {}; 201 } 202 203 // Create a Sentence storing the input text. 204 LightSentence sentence; 205 TokenizeTextForLangId(text, &sentence); 206 207 if (GetRealTextSize(sentence) < min_text_size_in_bytes_) { 208 return {}; 209 } 210 211 // TODO(salcianu): reuse vector<FeatureVector>. 212 std::vector<FeatureVector> features( 213 lang_id_brain_interface_.NumEmbeddings()); 214 lang_id_brain_interface_.GetFeatures(&sentence, &features); 215 216 // Predict language. 217 EmbeddingNetwork::Vector scores; 218 network_->ComputeFinalScores(features, &scores); 219 220 return ComputeSoftmax(scores); 221 } 222 223 bool is_valid() const { return valid_; } 224 225 int version() const { return version_; } 226 227 private: 228 // Returns name of the (in-memory) file for the indicated TaskInput from 229 // context. 230 static std::string GetInMemoryFileNameForTaskInput( 231 const std::string &input_name, TaskContext *context) { 232 TaskInput *task_input = context->GetInput(input_name); 233 if (task_input->part_size() != 1) { 234 TC_LOG(ERROR) << "TaskInput " << input_name << " has " 235 << task_input->part_size() << " parts"; 236 return ""; 237 } 238 return task_input->part(0).file_pattern(); 239 } 240 241 bool ParseNetworkParams(const InMemoryModelData &model_data, 242 TaskContext *context) { 243 const std::string input_name = "language-identifier-network"; 244 const std::string input_file_name = 245 GetInMemoryFileNameForTaskInput(input_name, context); 246 if (input_file_name.empty()) { 247 TC_LOG(ERROR) << "No input file name for TaskInput " << input_name; 248 return false; 249 } 250 StringPiece bytes = model_data.GetBytesForInputFile(input_file_name); 251 if (bytes.data() == nullptr) { 252 TC_LOG(ERROR) << "Unable to get bytes for TaskInput " << input_name; 253 return false; 254 } 255 std::unique_ptr<EmbeddingNetworkProto> proto(new EmbeddingNetworkProto()); 256 if (!ParseProtoFromMemory(bytes, proto.get())) { 257 TC_LOG(ERROR) << "Unable to parse EmbeddingNetworkProto"; 258 return false; 259 } 260 network_params_.reset( 261 new EmbeddingNetworkParamsFromProto(std::move(proto))); 262 if (!network_params_->is_valid()) { 263 TC_LOG(ERROR) << "EmbeddingNetworkParamsFromProto not valid"; 264 return false; 265 } 266 return true; 267 } 268 269 // Parses dictionary with known languages (i.e., field languages_) from a 270 // TaskInput of context. Note: that TaskInput should be a ListOfStrings proto 271 // with a single element, the serialized form of a ListOfStrings. 272 // 273 bool ParseListOfKnownLanguages(const InMemoryModelData &model_data, 274 TaskContext *context) { 275 const std::string input_name = "language-name-id-map"; 276 const std::string input_file_name = 277 GetInMemoryFileNameForTaskInput(input_name, context); 278 if (input_file_name.empty()) { 279 TC_LOG(ERROR) << "No input file name for TaskInput " << input_name; 280 return false; 281 } 282 StringPiece bytes = model_data.GetBytesForInputFile(input_file_name); 283 if (bytes.data() == nullptr) { 284 TC_LOG(ERROR) << "Unable to get bytes for TaskInput " << input_name; 285 return false; 286 } 287 ListOfStrings records; 288 if (!ParseProtoFromMemory(bytes, &records)) { 289 TC_LOG(ERROR) << "Unable to parse ListOfStrings from TaskInput " 290 << input_name; 291 return false; 292 } 293 if (records.element_size() != 1) { 294 TC_LOG(ERROR) << "Wrong number of records in TaskInput " << input_name 295 << " : " << records.element_size(); 296 return false; 297 } 298 if (!ParseProtoFromMemory(std::string(records.element(0)), &languages_)) { 299 TC_LOG(ERROR) << "Unable to parse dictionary with known languages"; 300 return false; 301 } 302 return true; 303 } 304 305 // Returns language code for a softmax label. See comments for languages_ 306 // field. If label is out of range, returns default_language_. 307 std::string GetLanguageForSoftmaxLabel(int label) const { 308 if ((label >= 0) && (label < languages_.element_size())) { 309 return languages_.element(label); 310 } else { 311 TC_LOG(ERROR) << "Softmax label " << label << " outside range [0, " 312 << languages_.element_size() << ")"; 313 return default_language_; 314 } 315 } 316 317 LangIdBrainInterface lang_id_brain_interface_; 318 319 // Parameters for the neural network network_ (see below). 320 std::unique_ptr<EmbeddingNetworkParamsFromProto> network_params_; 321 322 // Neural network to use for scoring. 323 std::unique_ptr<EmbeddingNetwork> network_; 324 325 // True if this object is ready to perform language predictions. 326 bool valid_; 327 328 // Only predictions with a probability (confidence) above this threshold are 329 // reported. Otherwise, we report default_language_. 330 float probability_threshold_ = kDefaultProbabilityThreshold; 331 332 // Min size of the input text for our predictions to be meaningful. Below 333 // this threshold, the underlying model may report a wrong language and a high 334 // confidence score. 335 int min_text_size_in_bytes_ = kDefaultMinTextSizeInBytes; 336 337 // Version of the model. 338 int version_ = -1; 339 340 // Known languages: softmax label i (an integer) means languages_.element(i) 341 // (something like "en", "fr", "ru", etc). 342 ListOfStrings languages_; 343 344 // Language code to return in case of errors. 345 std::string default_language_ = kInitialDefaultLanguage; 346 347 TC_DISALLOW_COPY_AND_ASSIGN(LangIdImpl); 348 }; 349 350 LangId::LangId(const std::string &filename) : pimpl_(new LangIdImpl(filename)) { 351 if (!pimpl_->is_valid()) { 352 TC_LOG(ERROR) << "Unable to construct a valid LangId based " 353 << "on the data from " << filename 354 << "; nothing should crash, but " 355 << "accuracy will be bad."; 356 } 357 } 358 359 LangId::LangId(int fd) : pimpl_(new LangIdImpl(fd)) { 360 if (!pimpl_->is_valid()) { 361 TC_LOG(ERROR) << "Unable to construct a valid LangId based " 362 << "on the data from descriptor " << fd 363 << "; nothing should crash, " 364 << "but accuracy will be bad."; 365 } 366 } 367 368 LangId::LangId(const char *ptr, size_t length) 369 : pimpl_(new LangIdImpl(ptr, length)) { 370 if (!pimpl_->is_valid()) { 371 TC_LOG(ERROR) << "Unable to construct a valid LangId based " 372 << "on the memory region; nothing should crash, " 373 << "but accuracy will be bad."; 374 } 375 } 376 377 LangId::~LangId() = default; 378 379 void LangId::SetProbabilityThreshold(float threshold) { 380 pimpl_->SetProbabilityThreshold(threshold); 381 } 382 383 void LangId::SetDefaultLanguage(const std::string &lang) { 384 pimpl_->SetDefaultLanguage(lang); 385 } 386 387 std::string LangId::FindLanguage(const std::string &text) const { 388 return pimpl_->FindLanguage(text); 389 } 390 391 std::vector<std::pair<std::string, float>> LangId::FindLanguages( 392 const std::string &text) const { 393 return pimpl_->FindLanguages(text); 394 } 395 396 bool LangId::is_valid() const { return pimpl_->is_valid(); } 397 398 int LangId::version() const { return pimpl_->version(); } 399 400 } // namespace lang_id 401 } // namespace nlp_core 402 } // namespace libtextclassifier 403