Home | History | Annotate | Download | only in features
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "lang_id/features/char-ngram-feature.h"
     18 
     19 #include <utility>
     20 #include <vector>
     21 
     22 #include "lang_id/common/fel/feature-types.h"
     23 #include "lang_id/common/fel/task-context.h"
     24 #include "lang_id/common/lite_base/logging.h"
     25 #include "lang_id/common/math/hash.h"
     26 #include "lang_id/common/utf8.h"
     27 
     28 namespace libtextclassifier3 {
     29 namespace mobile {
     30 namespace lang_id {
     31 
     32 bool ContinuousBagOfNgramsFunction::Setup(TaskContext *context) {
     33   // Parameters in the feature function descriptor.
     34   bool include_terminators = GetBoolParameter("include_terminators", false);
     35   if (!include_terminators) {
     36     SAFTM_LOG(ERROR) << "No support for include_terminators=true";
     37     return false;
     38   }
     39 
     40   bool include_spaces = GetBoolParameter("include_spaces", false);
     41   if (include_spaces) {
     42     SAFTM_LOG(ERROR) << "No support for include_spaces=true";
     43     return false;
     44   }
     45 
     46   bool use_equal_ngram_weight = GetBoolParameter("use_equal_weight", false);
     47   if (use_equal_ngram_weight) {
     48     SAFTM_LOG(ERROR) << "No support for use_equal_weight=true";
     49     return false;
     50   }
     51 
     52   ngram_id_dimension_ = GetIntParameter("id_dim", 10000);
     53   ngram_size_ = GetIntParameter("size", 3);
     54 
     55   counts_.assign(ngram_id_dimension_, 0);
     56   return true;
     57 }
     58 
     59 bool ContinuousBagOfNgramsFunction::Init(TaskContext *context) {
     60   set_feature_type(new NumericFeatureType(name(), ngram_id_dimension_));
     61   return true;
     62 }
     63 
     64 int ContinuousBagOfNgramsFunction::ComputeNgramCounts(
     65     const LightSentence &sentence) const {
     66   SAFTM_CHECK_EQ(counts_.size(), ngram_id_dimension_);
     67   SAFTM_CHECK_EQ(non_zero_count_indices_.size(), 0);
     68 
     69   int total_count = 0;
     70 
     71   for (const string &word : sentence) {
     72     const char *const word_end = word.data() + word.size();
     73 
     74     // Set ngram_start at the start of the current token (word).
     75     const char *ngram_start = word.data();
     76 
     77     // Set ngram_end ngram_size UTF8 characters after ngram_start.  Note: each
     78     // UTF8 character contains between 1 and 4 bytes.
     79     const char *ngram_end = ngram_start;
     80     int num_utf8_chars = 0;
     81     do {
     82       ngram_end += utils::OneCharLen(ngram_end);
     83       num_utf8_chars++;
     84     } while ((num_utf8_chars < ngram_size_) && (ngram_end < word_end));
     85 
     86     if (num_utf8_chars < ngram_size_) {
     87       // Current token is so small, it does not contain a single ngram of
     88       // ngram_size UTF8 characters.  Not much we can do in this case ...
     89       continue;
     90     }
     91 
     92     // At this point, [ngram_start, ngram_end) is the first ngram of ngram_size
     93     // UTF8 characters from current token.
     94     while (true) {
     95       // Compute ngram id: hash(ngram) % ngram_id_dimension
     96       int ngram_id = (
     97           utils::Hash32WithDefaultSeed(ngram_start, ngram_end - ngram_start)
     98           % ngram_id_dimension_);
     99 
    100       // Use a reference to the actual count, such that we can both test whether
    101       // the count was 0 and increment it without perfoming two lookups.
    102       int &ref_to_count_for_ngram = counts_[ngram_id];
    103       if (ref_to_count_for_ngram == 0) {
    104         non_zero_count_indices_.push_back(ngram_id);
    105       }
    106       ref_to_count_for_ngram++;
    107       total_count++;
    108       if (ngram_end >= word_end) {
    109         break;
    110       }
    111 
    112       // Advance both ngram_start and ngram_end by one UTF8 character.  This
    113       // way, the number of UTF8 characters between them remains constant
    114       // (ngram_size).
    115       ngram_start += utils::OneCharLen(ngram_start);
    116       ngram_end += utils::OneCharLen(ngram_end);
    117     }
    118   }  // end of loop over tokens.
    119 
    120   return total_count;
    121 }
    122 
    123 void ContinuousBagOfNgramsFunction::Evaluate(const WorkspaceSet &workspaces,
    124                                              const LightSentence &sentence,
    125                                              FeatureVector *result) const {
    126   // NOTE: we use std::* constructs (instead of absl::Mutex & co) to simplify
    127   // porting to Android and to avoid pulling in absl (which increases our code
    128   // size).
    129   std::lock_guard<std::mutex> mlock(state_mutex_);
    130 
    131   // Find the char ngram counts.
    132   int total_count = ComputeNgramCounts(sentence);
    133 
    134   // Populate the feature vector.
    135   const float norm = static_cast<float>(total_count);
    136 
    137   // TODO(salcianu): explore treating dense vectors (i.e., many non-zero
    138   // elements) separately.
    139   for (int ngram_id : non_zero_count_indices_) {
    140     const float weight = counts_[ngram_id] / norm;
    141     FloatFeatureValue value(ngram_id, weight);
    142     result->add(feature_type(), value.discrete_value);
    143 
    144     // Clear up counts_, for the next invocation of Evaluate().
    145     counts_[ngram_id] = 0;
    146   }
    147 
    148   // Clear up non_zero_count_indices_, for the next invocation of Evaluate().
    149   non_zero_count_indices_.clear();
    150 }
    151 
// Static registration of ContinuousBagOfNgramsFunction.
// NOTE(review): presumably this adds the class to the feature-function
// registry, so it can be instantiated by name from a feature descriptor;
// confirm against the SAFTM_STATIC_REGISTRATION definition.
SAFTM_STATIC_REGISTRATION(ContinuousBagOfNgramsFunction);
    153 
    154 }  // namespace lang_id
    155 }  // namespace mobile
}  // namespace libtextclassifier3
    157