Home | History | Annotate | Download | only in lang_id
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "lang_id/lang-id.h"
     18 
     19 #include <stdio.h>
     20 
     21 #include <algorithm>
     22 #include <limits>
     23 #include <memory>
     24 #include <string>
     25 #include <vector>
     26 
     27 #include "common/algorithm.h"
     28 #include "common/embedding-network-params-from-proto.h"
     29 #include "common/embedding-network.pb.h"
     30 #include "common/embedding-network.h"
     31 #include "common/feature-extractor.h"
     32 #include "common/file-utils.h"
     33 #include "common/list-of-strings.pb.h"
     34 #include "common/memory_image/in-memory-model-data.h"
     35 #include "common/mmap.h"
     36 #include "common/softmax.h"
     37 #include "common/task-context.h"
     38 #include "lang_id/custom-tokenizer.h"
     39 #include "lang_id/lang-id-brain-interface.h"
     40 #include "lang_id/language-identifier-features.h"
     41 #include "lang_id/light-sentence-features.h"
     42 #include "lang_id/light-sentence.h"
     43 #include "lang_id/relevant-script-feature.h"
     44 #include "util/base/logging.h"
     45 #include "util/base/macros.h"
     46 
     47 using ::libtextclassifier::nlp_core::file_utils::ParseProtoFromMemory;
     48 
     49 namespace libtextclassifier {
     50 namespace nlp_core {
     51 namespace lang_id {
     52 
     53 namespace {
     54 // Default value for the probability threshold; see comments for
     55 // LangId::SetProbabilityThreshold().
     56 static const float kDefaultProbabilityThreshold = 0.50;
     57 
     58 // Default value for min text size below which our model can't provide a
     59 // meaningful prediction.
     60 static const int kDefaultMinTextSizeInBytes = 20;
     61 
     62 // Initial value for the default language for LangId::FindLanguage().  The
     63 // default language can be changed (for an individual LangId object) using
     64 // LangId::SetDefaultLanguage().
     65 static const char kInitialDefaultLanguage[] = "";
     66 
     67 // Returns total number of bytes of the words from sentence, without the ^
     68 // (start-of-word) and $ (end-of-word) markers.  Note: "real text" means that
     69 // this ignores whitespace and punctuation characters from the original text.
     70 int GetRealTextSize(const LightSentence &sentence) {
     71   int total = 0;
     72   for (int i = 0; i < sentence.num_words(); ++i) {
     73     TC_DCHECK(!sentence.word(i).empty());
     74     TC_DCHECK_EQ('^', sentence.word(i).front());
     75     TC_DCHECK_EQ('$', sentence.word(i).back());
     76     total += sentence.word(i).size() - 2;
     77   }
     78   return total;
     79 }
     80 
     81 }  // namespace
     82 
     83 // Class that performs all work behind LangId.
     84 class LangIdImpl {
     85  public:
     86   explicit LangIdImpl(const std::string &filename) {
     87     // Using mmap as a fast way to read the model bytes.
     88     ScopedMmap scoped_mmap(filename);
     89     MmapHandle mmap_handle = scoped_mmap.handle();
     90     if (!mmap_handle.ok()) {
     91       TC_LOG(ERROR) << "Unable to read model bytes.";
     92       return;
     93     }
     94 
     95     Initialize(mmap_handle.to_stringpiece());
     96   }
     97 
     98   explicit LangIdImpl(int fd) {
     99     // Using mmap as a fast way to read the model bytes.
    100     ScopedMmap scoped_mmap(fd);
    101     MmapHandle mmap_handle = scoped_mmap.handle();
    102     if (!mmap_handle.ok()) {
    103       TC_LOG(ERROR) << "Unable to read model bytes.";
    104       return;
    105     }
    106 
    107     Initialize(mmap_handle.to_stringpiece());
    108   }
    109 
    110   LangIdImpl(const char *ptr, size_t length) {
    111     Initialize(StringPiece(ptr, length));
    112   }
    113 
    114   void Initialize(StringPiece model_bytes) {
    115     // Will set valid_ to true only on successful initialization.
    116     valid_ = false;
    117 
    118     // Make sure all relevant features are registered:
    119     ContinuousBagOfNgramsFunction::RegisterClass();
    120     RelevantScriptFeature::RegisterClass();
    121 
    122     // NOTE(salcianu): code below relies on the fact that the current features
    123     // do not rely on data from a TaskInput.  Otherwise, one would have to use
    124     // the more complex model registration mechanism, which requires more code.
    125     InMemoryModelData model_data(model_bytes);
    126     TaskContext context;
    127     if (!model_data.GetTaskSpec(context.mutable_spec())) {
    128       TC_LOG(ERROR) << "Unable to get model TaskSpec";
    129       return;
    130     }
    131 
    132     if (!ParseNetworkParams(model_data, &context)) {
    133       return;
    134     }
    135     if (!ParseListOfKnownLanguages(model_data, &context)) {
    136       return;
    137     }
    138 
    139     network_.reset(new EmbeddingNetwork(network_params_.get()));
    140     if (!network_->is_valid()) {
    141       return;
    142     }
    143 
    144     probability_threshold_ =
    145         context.Get("reliability_thresh", kDefaultProbabilityThreshold);
    146     min_text_size_in_bytes_ =
    147         context.Get("min_text_size_in_bytes", kDefaultMinTextSizeInBytes);
    148     version_ = context.Get("version", 0);
    149 
    150     if (!lang_id_brain_interface_.Init(&context)) {
    151       return;
    152     }
    153     valid_ = true;
    154   }
    155 
    156   void SetProbabilityThreshold(float threshold) {
    157     probability_threshold_ = threshold;
    158   }
    159 
    160   void SetDefaultLanguage(const std::string &lang) { default_language_ = lang; }
    161 
    162   std::string FindLanguage(const std::string &text) const {
    163     std::vector<float> scores = ScoreLanguages(text);
    164     if (scores.empty()) {
    165       return default_language_;
    166     }
    167 
    168     // Softmax label with max score.
    169     int label = GetArgMax(scores);
    170     float probability = scores[label];
    171     if (probability < probability_threshold_) {
    172       return default_language_;
    173     }
    174     return GetLanguageForSoftmaxLabel(label);
    175   }
    176 
    177   std::vector<std::pair<std::string, float>> FindLanguages(
    178       const std::string &text) const {
    179     std::vector<float> scores = ScoreLanguages(text);
    180 
    181     std::vector<std::pair<std::string, float>> result;
    182     for (int i = 0; i < scores.size(); i++) {
    183       result.push_back({GetLanguageForSoftmaxLabel(i), scores[i]});
    184     }
    185 
    186     // To avoid crashing clients that always expect at least one predicted
    187     // language, we promised (see doc for this method) that the result always
    188     // contains at least one element.
    189     if (result.empty()) {
    190       // We use a tiny probability, such that any client that uses a meaningful
    191       // probability threshold ignores this prediction.  We don't use 0.0f, to
    192       // avoid crashing clients that normalize the probabilities we return here.
    193       result.push_back({default_language_, 0.001f});
    194     }
    195     return result;
    196   }
    197 
    198   std::vector<float> ScoreLanguages(const std::string &text) const {
    199     if (!is_valid()) {
    200       return {};
    201     }
    202 
    203     // Create a Sentence storing the input text.
    204     LightSentence sentence;
    205     TokenizeTextForLangId(text, &sentence);
    206 
    207     if (GetRealTextSize(sentence) < min_text_size_in_bytes_) {
    208       return {};
    209     }
    210 
    211     // TODO(salcianu): reuse vector<FeatureVector>.
    212     std::vector<FeatureVector> features(
    213         lang_id_brain_interface_.NumEmbeddings());
    214     lang_id_brain_interface_.GetFeatures(&sentence, &features);
    215 
    216     // Predict language.
    217     EmbeddingNetwork::Vector scores;
    218     network_->ComputeFinalScores(features, &scores);
    219 
    220     return ComputeSoftmax(scores);
    221   }
    222 
    223   bool is_valid() const { return valid_; }
    224 
    225   int version() const { return version_; }
    226 
    227  private:
    228   // Returns name of the (in-memory) file for the indicated TaskInput from
    229   // context.
    230   static std::string GetInMemoryFileNameForTaskInput(
    231       const std::string &input_name, TaskContext *context) {
    232     TaskInput *task_input = context->GetInput(input_name);
    233     if (task_input->part_size() != 1) {
    234       TC_LOG(ERROR) << "TaskInput " << input_name << " has "
    235                     << task_input->part_size() << " parts";
    236       return "";
    237     }
    238     return task_input->part(0).file_pattern();
    239   }
    240 
    241   bool ParseNetworkParams(const InMemoryModelData &model_data,
    242                           TaskContext *context) {
    243     const std::string input_name = "language-identifier-network";
    244     const std::string input_file_name =
    245         GetInMemoryFileNameForTaskInput(input_name, context);
    246     if (input_file_name.empty()) {
    247       TC_LOG(ERROR) << "No input file name for TaskInput " << input_name;
    248       return false;
    249     }
    250     StringPiece bytes = model_data.GetBytesForInputFile(input_file_name);
    251     if (bytes.data() == nullptr) {
    252       TC_LOG(ERROR) << "Unable to get bytes for TaskInput " << input_name;
    253       return false;
    254     }
    255     std::unique_ptr<EmbeddingNetworkProto> proto(new EmbeddingNetworkProto());
    256     if (!ParseProtoFromMemory(bytes, proto.get())) {
    257       TC_LOG(ERROR) << "Unable to parse EmbeddingNetworkProto";
    258       return false;
    259     }
    260     network_params_.reset(
    261         new EmbeddingNetworkParamsFromProto(std::move(proto)));
    262     if (!network_params_->is_valid()) {
    263       TC_LOG(ERROR) << "EmbeddingNetworkParamsFromProto not valid";
    264       return false;
    265     }
    266     return true;
    267   }
    268 
    269   // Parses dictionary with known languages (i.e., field languages_) from a
    270   // TaskInput of context.  Note: that TaskInput should be a ListOfStrings proto
    271   // with a single element, the serialized form of a ListOfStrings.
    272   //
    273   bool ParseListOfKnownLanguages(const InMemoryModelData &model_data,
    274                                  TaskContext *context) {
    275     const std::string input_name = "language-name-id-map";
    276     const std::string input_file_name =
    277         GetInMemoryFileNameForTaskInput(input_name, context);
    278     if (input_file_name.empty()) {
    279       TC_LOG(ERROR) << "No input file name for TaskInput " << input_name;
    280       return false;
    281     }
    282     StringPiece bytes = model_data.GetBytesForInputFile(input_file_name);
    283     if (bytes.data() == nullptr) {
    284       TC_LOG(ERROR) << "Unable to get bytes for TaskInput " << input_name;
    285       return false;
    286     }
    287     ListOfStrings records;
    288     if (!ParseProtoFromMemory(bytes, &records)) {
    289       TC_LOG(ERROR) << "Unable to parse ListOfStrings from TaskInput "
    290                     << input_name;
    291       return false;
    292     }
    293     if (records.element_size() != 1) {
    294       TC_LOG(ERROR) << "Wrong number of records in TaskInput " << input_name
    295                     << " : " << records.element_size();
    296       return false;
    297     }
    298     if (!ParseProtoFromMemory(std::string(records.element(0)), &languages_)) {
    299       TC_LOG(ERROR) << "Unable to parse dictionary with known languages";
    300       return false;
    301     }
    302     return true;
    303   }
    304 
    305   // Returns language code for a softmax label.  See comments for languages_
    306   // field.  If label is out of range, returns default_language_.
    307   std::string GetLanguageForSoftmaxLabel(int label) const {
    308     if ((label >= 0) && (label < languages_.element_size())) {
    309       return languages_.element(label);
    310     } else {
    311       TC_LOG(ERROR) << "Softmax label " << label << " outside range [0, "
    312                     << languages_.element_size() << ")";
    313       return default_language_;
    314     }
    315   }
    316 
    317   LangIdBrainInterface lang_id_brain_interface_;
    318 
    319   // Parameters for the neural network network_ (see below).
    320   std::unique_ptr<EmbeddingNetworkParamsFromProto> network_params_;
    321 
    322   // Neural network to use for scoring.
    323   std::unique_ptr<EmbeddingNetwork> network_;
    324 
    325   // True if this object is ready to perform language predictions.
    326   bool valid_;
    327 
    328   // Only predictions with a probability (confidence) above this threshold are
    329   // reported.  Otherwise, we report default_language_.
    330   float probability_threshold_ = kDefaultProbabilityThreshold;
    331 
    332   // Min size of the input text for our predictions to be meaningful.  Below
    333   // this threshold, the underlying model may report a wrong language and a high
    334   // confidence score.
    335   int min_text_size_in_bytes_ = kDefaultMinTextSizeInBytes;
    336 
    337   // Version of the model.
    338   int version_ = -1;
    339 
    340   // Known languages: softmax label i (an integer) means languages_.element(i)
    341   // (something like "en", "fr", "ru", etc).
    342   ListOfStrings languages_;
    343 
    344   // Language code to return in case of errors.
    345   std::string default_language_ = kInitialDefaultLanguage;
    346 
    347   TC_DISALLOW_COPY_AND_ASSIGN(LangIdImpl);
    348 };
    349 
    350 LangId::LangId(const std::string &filename) : pimpl_(new LangIdImpl(filename)) {
    351   if (!pimpl_->is_valid()) {
    352     TC_LOG(ERROR) << "Unable to construct a valid LangId based "
    353                   << "on the data from " << filename
    354                   << "; nothing should crash, but "
    355                   << "accuracy will be bad.";
    356   }
    357 }
    358 
    359 LangId::LangId(int fd) : pimpl_(new LangIdImpl(fd)) {
    360   if (!pimpl_->is_valid()) {
    361     TC_LOG(ERROR) << "Unable to construct a valid LangId based "
    362                   << "on the data from descriptor " << fd
    363                   << "; nothing should crash, "
    364                   << "but accuracy will be bad.";
    365   }
    366 }
    367 
    368 LangId::LangId(const char *ptr, size_t length)
    369     : pimpl_(new LangIdImpl(ptr, length)) {
    370   if (!pimpl_->is_valid()) {
    371     TC_LOG(ERROR) << "Unable to construct a valid LangId based "
    372                   << "on the memory region; nothing should crash, "
    373                   << "but accuracy will be bad.";
    374   }
    375 }
    376 
    377 LangId::~LangId() = default;
    378 
    379 void LangId::SetProbabilityThreshold(float threshold) {
    380   pimpl_->SetProbabilityThreshold(threshold);
    381 }
    382 
    383 void LangId::SetDefaultLanguage(const std::string &lang) {
    384   pimpl_->SetDefaultLanguage(lang);
    385 }
    386 
    387 std::string LangId::FindLanguage(const std::string &text) const {
    388   return pimpl_->FindLanguage(text);
    389 }
    390 
    391 std::vector<std::pair<std::string, float>> LangId::FindLanguages(
    392     const std::string &text) const {
    393   return pimpl_->FindLanguages(text);
    394 }
    395 
    396 bool LangId::is_valid() const { return pimpl_->is_valid(); }
    397 
    398 int LangId::version() const { return pimpl_->version(); }
    399 
    400 }  // namespace lang_id
    401 }  // namespace nlp_core
    402 }  // namespace libtextclassifier
    403