Home | History | Annotate | Download | only in lang_id
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "lang_id/language-identifier-features.h"
     18 
     19 #include <utility>
     20 #include <vector>
     21 
     22 #include "common/feature-extractor.h"
     23 #include "common/feature-types.h"
     24 #include "common/task-context.h"
     25 #include "util/hash/hash.h"
     26 #include "util/strings/utf8.h"
     27 
     28 namespace libtextclassifier {
     29 namespace nlp_core {
     30 namespace lang_id {
     31 
     32 bool ContinuousBagOfNgramsFunction::Setup(TaskContext *context) {
     33   // Parameters in the feature function descriptor.
     34   ngram_id_dimension_ = GetIntParameter("id_dim", 10000);
     35   ngram_size_ = GetIntParameter("size", 3);
     36 
     37   counts_.assign(ngram_id_dimension_, 0);
     38   return true;
     39 }
     40 
     41 bool ContinuousBagOfNgramsFunction::Init(TaskContext *context) {
     42   set_feature_type(new NumericFeatureType(name(), ngram_id_dimension_));
     43   return true;
     44 }
     45 
     46 int ContinuousBagOfNgramsFunction::ComputeNgramCounts(
     47     const LightSentence &sentence) const {
     48   // Invariant 1: counts_.size() == ngram_id_dimension_.  Holds at the end of
     49   // the constructor.  After that, no method changes counts_.size().
     50   TC_DCHECK_EQ(counts_.size(), ngram_id_dimension_);
     51 
     52   // Invariant 2: the vector non_zero_count_indices_ is empty.  The vector
     53   // non_zero_count_indices_ is empty at construction time and gets emptied at
     54   // the end of each call to Evaluate().  Hence, this invariant holds at the
     55   // beginning of each run of Evaluate(), where the only call to this code takes
     56   // place.
     57   TC_DCHECK(non_zero_count_indices_.empty());
     58 
     59   int total_count = 0;
     60 
     61   for (int i = 0; i < sentence.num_words(); ++i) {
     62     const std::string &word = sentence.word(i);
     63     const char *const word_end = word.data() + word.size();
     64 
     65     // Set ngram_start at the start of the current token (word).
     66     const char *ngram_start = word.data();
     67 
     68     // Set ngram_end ngram_size UTF8 characters after ngram_start.  Note: each
     69     // UTF8 character contains between 1 and 4 bytes.
     70     const char *ngram_end = ngram_start;
     71     int num_utf8_chars = 0;
     72     do {
     73       ngram_end += GetNumBytesForNonZeroUTF8Char(ngram_end);
     74       num_utf8_chars++;
     75     } while ((num_utf8_chars < ngram_size_) && (ngram_end < word_end));
     76 
     77     if (num_utf8_chars < ngram_size_) {
     78       // Current token is so small, it does not contain a single ngram of
     79       // ngram_size UTF8 characters.  Not much we can do in this case ...
     80       continue;
     81     }
     82 
     83     // At this point, [ngram_start, ngram_end) is the first ngram of ngram_size
     84     // UTF8 characters from current token.
     85     while (true) {
     86       // Compute ngram_id: hash(ngram) % ngram_id_dimension
     87       int ngram_id =
     88           (Hash32WithDefaultSeed(ngram_start, ngram_end - ngram_start) %
     89            ngram_id_dimension_);
     90 
     91       // Use a reference to the actual count, such that we can both test whether
     92       // the count was 0 and increment it without perfoming two lookups.
     93       //
     94       // Due to the way we compute ngram_id, 0 <= ngram_id < ngram_id_dimension.
     95       // Hence, by Invariant 1 (above), the access counts_[ngram_id] is safe.
     96       int &ref_to_count_for_ngram = counts_[ngram_id];
     97       if (ref_to_count_for_ngram == 0) {
     98         non_zero_count_indices_.push_back(ngram_id);
     99       }
    100       ref_to_count_for_ngram++;
    101       total_count++;
    102       if (ngram_end >= word_end) {
    103         break;
    104       }
    105 
    106       // Advance both ngram_start and ngram_end by one UTF8 character.  This
    107       // way, the number of UTF8 characters between them remains constant
    108       // (ngram_size).
    109       ngram_start += GetNumBytesForNonZeroUTF8Char(ngram_start);
    110       ngram_end += GetNumBytesForNonZeroUTF8Char(ngram_end);
    111     }
    112   }  // end of loop over tokens.
    113 
    114   return total_count;
    115 }
    116 
    117 void ContinuousBagOfNgramsFunction::Evaluate(const WorkspaceSet &workspaces,
    118                                              const LightSentence &sentence,
    119                                              FeatureVector *result) const {
    120   // Find the char ngram counts.
    121   int total_count = ComputeNgramCounts(sentence);
    122 
    123   // Populate the feature vector.
    124   const float norm = static_cast<float>(total_count);
    125 
    126   for (int ngram_id : non_zero_count_indices_) {
    127     const float weight = counts_[ngram_id] / norm;
    128     FloatFeatureValue value(ngram_id, weight);
    129     result->add(feature_type(), value.discrete_value);
    130 
    131     // Clear up counts_, for the next invocation of Evaluate().
    132     counts_[ngram_id] = 0;
    133   }
    134 
    135   // Clear up non_zero_count_indices_, for the next invocation of Evaluate().
    136   non_zero_count_indices_.clear();
    137 }
    138 
    139 }  // namespace lang_id
    140 }  // namespace nlp_core
    141 }  // namespace libtextclassifier
    142