1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "lang_id/features/char-ngram-feature.h" 18 19 #include <utility> 20 #include <vector> 21 22 #include "lang_id/common/fel/feature-types.h" 23 #include "lang_id/common/fel/task-context.h" 24 #include "lang_id/common/lite_base/logging.h" 25 #include "lang_id/common/math/hash.h" 26 #include "lang_id/common/utf8.h" 27 28 namespace libtextclassifier3 { 29 namespace mobile { 30 namespace lang_id { 31 32 bool ContinuousBagOfNgramsFunction::Setup(TaskContext *context) { 33 // Parameters in the feature function descriptor. 34 bool include_terminators = GetBoolParameter("include_terminators", false); 35 if (!include_terminators) { 36 SAFTM_LOG(ERROR) << "No support for include_terminators=true"; 37 return false; 38 } 39 40 bool include_spaces = GetBoolParameter("include_spaces", false); 41 if (include_spaces) { 42 SAFTM_LOG(ERROR) << "No support for include_spaces=true"; 43 return false; 44 } 45 46 bool use_equal_ngram_weight = GetBoolParameter("use_equal_weight", false); 47 if (use_equal_ngram_weight) { 48 SAFTM_LOG(ERROR) << "No support for use_equal_weight=true"; 49 return false; 50 } 51 52 ngram_id_dimension_ = GetIntParameter("id_dim", 10000); 53 ngram_size_ = GetIntParameter("size", 3); 54 55 counts_.assign(ngram_id_dimension_, 0); 56 return true; 57 } 58 59 bool ContinuousBagOfNgramsFunction::Init(TaskContext *context) { 60 set_feature_type(new NumericFeatureType(name(), ngram_id_dimension_)); 61 return true; 62 } 63 64 int ContinuousBagOfNgramsFunction::ComputeNgramCounts( 65 const LightSentence &sentence) const { 66 SAFTM_CHECK_EQ(counts_.size(), ngram_id_dimension_); 67 SAFTM_CHECK_EQ(non_zero_count_indices_.size(), 0); 68 69 int total_count = 0; 70 71 for (const string &word : sentence) { 72 const char *const word_end = word.data() + word.size(); 73 74 // Set ngram_start at the start of the current token (word). 75 const char *ngram_start = word.data(); 76 77 // Set ngram_end ngram_size UTF8 characters after ngram_start. Note: each 78 // UTF8 character contains between 1 and 4 bytes. 79 const char *ngram_end = ngram_start; 80 int num_utf8_chars = 0; 81 do { 82 ngram_end += utils::OneCharLen(ngram_end); 83 num_utf8_chars++; 84 } while ((num_utf8_chars < ngram_size_) && (ngram_end < word_end)); 85 86 if (num_utf8_chars < ngram_size_) { 87 // Current token is so small, it does not contain a single ngram of 88 // ngram_size UTF8 characters. Not much we can do in this case ... 89 continue; 90 } 91 92 // At this point, [ngram_start, ngram_end) is the first ngram of ngram_size 93 // UTF8 characters from current token. 94 while (true) { 95 // Compute ngram id: hash(ngram) % ngram_id_dimension 96 int ngram_id = ( 97 utils::Hash32WithDefaultSeed(ngram_start, ngram_end - ngram_start) 98 % ngram_id_dimension_); 99 100 // Use a reference to the actual count, such that we can both test whether 101 // the count was 0 and increment it without perfoming two lookups. 102 int &ref_to_count_for_ngram = counts_[ngram_id]; 103 if (ref_to_count_for_ngram == 0) { 104 non_zero_count_indices_.push_back(ngram_id); 105 } 106 ref_to_count_for_ngram++; 107 total_count++; 108 if (ngram_end >= word_end) { 109 break; 110 } 111 112 // Advance both ngram_start and ngram_end by one UTF8 character. This 113 // way, the number of UTF8 characters between them remains constant 114 // (ngram_size). 115 ngram_start += utils::OneCharLen(ngram_start); 116 ngram_end += utils::OneCharLen(ngram_end); 117 } 118 } // end of loop over tokens. 119 120 return total_count; 121 } 122 123 void ContinuousBagOfNgramsFunction::Evaluate(const WorkspaceSet &workspaces, 124 const LightSentence &sentence, 125 FeatureVector *result) const { 126 // NOTE: we use std::* constructs (instead of absl::Mutex & co) to simplify 127 // porting to Android and to avoid pulling in absl (which increases our code 128 // size). 129 std::lock_guard<std::mutex> mlock(state_mutex_); 130 131 // Find the char ngram counts. 132 int total_count = ComputeNgramCounts(sentence); 133 134 // Populate the feature vector. 135 const float norm = static_cast<float>(total_count); 136 137 // TODO(salcianu): explore treating dense vectors (i.e., many non-zero 138 // elements) separately. 139 for (int ngram_id : non_zero_count_indices_) { 140 const float weight = counts_[ngram_id] / norm; 141 FloatFeatureValue value(ngram_id, weight); 142 result->add(feature_type(), value.discrete_value); 143 144 // Clear up counts_, for the next invocation of Evaluate(). 145 counts_[ngram_id] = 0; 146 } 147 148 // Clear up non_zero_count_indices_, for the next invocation of Evaluate(). 149 non_zero_count_indices_.clear(); 150 } 151 152 SAFTM_STATIC_REGISTRATION(ContinuousBagOfNgramsFunction); 153 154 } // namespace lang_id 155 } // namespace mobile 156 } // namespace nlp_saft 157