1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "lang_id/language-identifier-features.h" 18 19 #include <utility> 20 #include <vector> 21 22 #include "common/feature-extractor.h" 23 #include "common/feature-types.h" 24 #include "common/task-context.h" 25 #include "util/hash/hash.h" 26 #include "util/strings/utf8.h" 27 28 namespace libtextclassifier { 29 namespace nlp_core { 30 namespace lang_id { 31 32 bool ContinuousBagOfNgramsFunction::Setup(TaskContext *context) { 33 // Parameters in the feature function descriptor. 34 ngram_id_dimension_ = GetIntParameter("id_dim", 10000); 35 ngram_size_ = GetIntParameter("size", 3); 36 37 counts_.assign(ngram_id_dimension_, 0); 38 return true; 39 } 40 41 bool ContinuousBagOfNgramsFunction::Init(TaskContext *context) { 42 set_feature_type(new NumericFeatureType(name(), ngram_id_dimension_)); 43 return true; 44 } 45 46 int ContinuousBagOfNgramsFunction::ComputeNgramCounts( 47 const LightSentence &sentence) const { 48 // Invariant 1: counts_.size() == ngram_id_dimension_. Holds at the end of 49 // the constructor. After that, no method changes counts_.size(). 50 TC_DCHECK_EQ(counts_.size(), ngram_id_dimension_); 51 52 // Invariant 2: the vector non_zero_count_indices_ is empty. The vector 53 // non_zero_count_indices_ is empty at construction time and gets emptied at 54 // the end of each call to Evaluate(). Hence, this invariant holds at the 55 // beginning of each run of Evaluate(), where the only call to this code takes 56 // place. 57 TC_DCHECK(non_zero_count_indices_.empty()); 58 59 int total_count = 0; 60 61 for (int i = 0; i < sentence.num_words(); ++i) { 62 const std::string &word = sentence.word(i); 63 const char *const word_end = word.data() + word.size(); 64 65 // Set ngram_start at the start of the current token (word). 66 const char *ngram_start = word.data(); 67 68 // Set ngram_end ngram_size UTF8 characters after ngram_start. Note: each 69 // UTF8 character contains between 1 and 4 bytes. 70 const char *ngram_end = ngram_start; 71 int num_utf8_chars = 0; 72 do { 73 ngram_end += GetNumBytesForNonZeroUTF8Char(ngram_end); 74 num_utf8_chars++; 75 } while ((num_utf8_chars < ngram_size_) && (ngram_end < word_end)); 76 77 if (num_utf8_chars < ngram_size_) { 78 // Current token is so small, it does not contain a single ngram of 79 // ngram_size UTF8 characters. Not much we can do in this case ... 80 continue; 81 } 82 83 // At this point, [ngram_start, ngram_end) is the first ngram of ngram_size 84 // UTF8 characters from current token. 85 while (true) { 86 // Compute ngram_id: hash(ngram) % ngram_id_dimension 87 int ngram_id = 88 (Hash32WithDefaultSeed(ngram_start, ngram_end - ngram_start) % 89 ngram_id_dimension_); 90 91 // Use a reference to the actual count, such that we can both test whether 92 // the count was 0 and increment it without perfoming two lookups. 93 // 94 // Due to the way we compute ngram_id, 0 <= ngram_id < ngram_id_dimension. 95 // Hence, by Invariant 1 (above), the access counts_[ngram_id] is safe. 96 int &ref_to_count_for_ngram = counts_[ngram_id]; 97 if (ref_to_count_for_ngram == 0) { 98 non_zero_count_indices_.push_back(ngram_id); 99 } 100 ref_to_count_for_ngram++; 101 total_count++; 102 if (ngram_end >= word_end) { 103 break; 104 } 105 106 // Advance both ngram_start and ngram_end by one UTF8 character. This 107 // way, the number of UTF8 characters between them remains constant 108 // (ngram_size). 109 ngram_start += GetNumBytesForNonZeroUTF8Char(ngram_start); 110 ngram_end += GetNumBytesForNonZeroUTF8Char(ngram_end); 111 } 112 } // end of loop over tokens. 113 114 return total_count; 115 } 116 117 void ContinuousBagOfNgramsFunction::Evaluate(const WorkspaceSet &workspaces, 118 const LightSentence &sentence, 119 FeatureVector *result) const { 120 // Find the char ngram counts. 121 int total_count = ComputeNgramCounts(sentence); 122 123 // Populate the feature vector. 124 const float norm = static_cast<float>(total_count); 125 126 for (int ngram_id : non_zero_count_indices_) { 127 const float weight = counts_[ngram_id] / norm; 128 FloatFeatureValue value(ngram_id, weight); 129 result->add(feature_type(), value.discrete_value); 130 131 // Clear up counts_, for the next invocation of Evaluate(). 132 counts_[ngram_id] = 0; 133 } 134 135 // Clear up non_zero_count_indices_, for the next invocation of Evaluate(). 136 non_zero_count_indices_.clear(); 137 } 138 139 } // namespace lang_id 140 } // namespace nlp_core 141 } // namespace libtextclassifier 142