/*
 * Copyright (C) 2014, The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "dictionary/structure/v4/content/language_model_dict_content.h"

#include <algorithm>
#include <cstring>

#include "dictionary/structure/v4/content/dynamic_language_model_probability_utils.h"
#include "dictionary/utils/probability_utils.h"
#include "utils/ngram_utils.h"

namespace latinime {

// Buffer slots used when the content is (de)serialized; the trie map and the
// global counters are stored in separate buffers.
const int LanguageModelDictContent::TRIE_MAP_BUFFER_INDEX = 0;
const int LanguageModelDictContent::GLOBAL_COUNTERS_BUFFER_INDEX = 1;

// Persists this content to the given file: first the trie map, then the
// global counters. Returns false if either write fails.
bool LanguageModelDictContent::save(FILE *const file) const {
    return mTrieMap.save(file) && mGlobalCounters.save(file);
}

// Rebuilds this (empty) content from originalContent, translating every old
// terminal id through terminalIdMap and dropping entries whose words have
// been removed. See runGCInner for the per-level work.
bool LanguageModelDictContent::runGC(
        const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
        const LanguageModelDictContent *const originalContent) {
    return runGCInner(terminalIdMap, originalContent->mTrieMap.getEntriesInRootLevel(),
            0 /* nextLevelBitmapEntryIndex */);
}

// Returns the word attributes (probability and unigram flags) for wordId in
// the context prevWordIds. The longest available prefix of prevWordIds is
// resolved in the trie map; the highest-order n-gram entry found wins. When
// mustMatchAllPrevWords is true, no result is returned unless the full
// context could be matched. For dictionaries with historical info, the
// probability is computed from counts (raw probability -> decay -> backoff);
// otherwise the stored probability is used directly.
const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArrayView prevWordIds,
        const int wordId, const bool mustMatchAllPrevWords,
        const HeaderPolicy *const headerPolicy) const {
    // bitmapEntryIndices[i] is the trie-map level reached after consuming the
    // first i previous words; index 0 is the root level.
    int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
    bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
    int maxPrevWordCount = 0;
    for (size_t i = 0; i < prevWordIds.size(); ++i) {
        const int nextBitmapEntryIndex =
                mTrieMap.get(prevWordIds[i], bitmapEntryIndices[i]).mNextLevelBitmapEntryIndex;
        if (nextBitmapEntryIndex == TrieMap::INVALID_INDEX) {
            // Context cannot be extended any further; use what we have.
            break;
        }
        maxPrevWordCount = i + 1;
        bitmapEntryIndices[i + 1] = nextBitmapEntryIndex;
    }

    const ProbabilityEntry unigramProbabilityEntry = getProbabilityEntry(wordId);
    if (mHasHistoricalInfo && unigramProbabilityEntry.getHistoricalInfo()->getCount() == 0) {
        // The word should be treated as an invalid word.
        return WordAttributes();
    }
    // Try the longest matched context first, backing off to shorter contexts
    // (down to the unigram at i == 0).
    for (int i = maxPrevWordCount; i >= 0; --i) {
        if (mustMatchAllPrevWords && prevWordIds.size() > static_cast<size_t>(i)) {
            // Caller demanded a full-context match and we are about to back
            // off to a shorter context; give up instead.
            break;
        }
        const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]);
        if (!result.mIsValid) {
            continue;
        }
        const ProbabilityEntry probabilityEntry =
                ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo);
        int probability = NOT_A_PROBABILITY;
        if (mHasHistoricalInfo) {
            const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo();
            // contextCount is the denominator for the raw probability: the
            // total count for unigrams, or the context word's count for
            // higher-order n-grams.
            int contextCount = 0;
            if (i == 0) {
                // unigram
                contextCount = mGlobalCounters.getTotalCount();
            } else {
                // The context entry for an i-gram context is the (i-1)-gram
                // entry of prevWordIds[0] given prevWordIds[1..i-1].
                const ProbabilityEntry prevWordProbabilityEntry = getNgramProbabilityEntry(
                        prevWordIds.skip(1 /* n */).limit(i - 1), prevWordIds[0]);
                if (!prevWordProbabilityEntry.isValid()) {
                    continue;
                }
                if (prevWordProbabilityEntry.representsBeginningOfSentence()
                        && historicalInfo->getCount() == 1) {
                    // BoS ngram requires multiple contextCount.
                    continue;
                }
                contextCount = prevWordProbabilityEntry.getHistoricalInfo()->getCount();
            }
            const NgramType ngramType = NgramUtils::getNgramTypeFromWordCount(i + 1);
            // count/contextCount -> encoded probability -> time-decayed
            // probability -> backoff penalty for the n-gram order.
            const float rawProbability =
                    DynamicLanguageModelProbabilityUtils::computeRawProbabilityFromCounts(
                            historicalInfo->getCount(), contextCount, ngramType);
            const int encodedRawProbability =
                    ProbabilityUtils::encodeRawProbability(rawProbability);
            const int decayedProbability =
                    DynamicLanguageModelProbabilityUtils::getDecayedProbability(
                            encodedRawProbability, *historicalInfo);
            probability = DynamicLanguageModelProbabilityUtils::backoff(
                    decayedProbability, ngramType);
        } else {
            probability = probabilityEntry.getProbability();
        }
        // TODO: Some flags in unigramProbabilityEntry should be overwritten by flags in
        // probabilityEntry.
        return WordAttributes(probability, unigramProbabilityEntry.isBlacklisted(),
                unigramProbabilityEntry.isNotAWord(),
                unigramProbabilityEntry.isPossiblyOffensive());
    }
    // Cannot find the word.
    return WordAttributes();
}

// Returns the probability entry stored for wordId in the context prevWordIds,
// or an invalid (default) ProbabilityEntry if the context or the entry does
// not exist.
ProbabilityEntry LanguageModelDictContent::getNgramProbabilityEntry(
        const WordIdArrayView prevWordIds, const int wordId) const {
    const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds);
    if (bitmapEntryIndex == TrieMap::INVALID_INDEX) {
        return ProbabilityEntry();
    }
    const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndex);
    if (!result.mIsValid) {
        // Not found.
        return ProbabilityEntry();
    }
    return ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo);
}

// Stores probabilityEntry for wordId in the context prevWordIds, creating any
// missing intermediate trie-map levels. Returns false for an invalid wordId
// or when the trie map cannot be updated.
bool LanguageModelDictContent::setNgramProbabilityEntry(const WordIdArrayView prevWordIds,
        const int wordId, const ProbabilityEntry *const probabilityEntry) {
    if (wordId == Ver4DictConstants::NOT_A_TERMINAL_ID) {
        return false;
    }
    const int bitmapEntryIndex = createAndGetBitmapEntryIndex(prevWordIds);
    if (bitmapEntryIndex == TrieMap::INVALID_INDEX) {
        return false;
    }
    return mTrieMap.put(wordId, probabilityEntry->encode(mHasHistoricalInfo), bitmapEntryIndex);
}

// Removes the entry for wordId in the context prevWordIds. Returns false if
// the context level does not exist or the removal fails.
bool LanguageModelDictContent::removeNgramProbabilityEntry(const WordIdArrayView prevWordIds,
        const int wordId) {
    const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds);
    if (bitmapEntryIndex == TrieMap::INVALID_INDEX) {
        // Cannot find bitmap entry for the probability entry. The entry doesn't exist.
        return false;
    }
    return mTrieMap.remove(wordId, bitmapEntryIndex);
}

// Returns an iterable range over all probability entries stored directly in
// the trie-map level addressed by prevWordIds.
// NOTE(review): when prevWordIds cannot be resolved, the range is built with
// TrieMap::INVALID_INDEX — presumably yielding an empty range; confirm
// against TrieMap::getEntriesInSpecifiedLevel.
LanguageModelDictContent::EntryRange LanguageModelDictContent::getProbabilityEntries(
        const WordIdArrayView prevWordIds) const {
    const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds);
    return EntryRange(mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex), mHasHistoricalInfo);
}

// Exports every n-gram entry whose context starts with wordId (i.e. all
// entries reachable below wordId's root-level trie node), with fully
// resolved word attributes. Returns an empty vector when wordId has no
// next-level entries.
std::vector<LanguageModelDictContent::DumppedFullEntryInfo>
        LanguageModelDictContent::exportAllNgramEntriesRelatedToWord(
                const HeaderPolicy *const headerPolicy, const int wordId) const {
    const TrieMap::Result result = mTrieMap.getRoot(wordId);
    if (!result.mIsValid || result.mNextLevelBitmapEntryIndex == TrieMap::INVALID_INDEX) {
        // The word doesn't have any related ngram entries.
        return std::vector<DumppedFullEntryInfo>();
    }
    std::vector<int> prevWordIds = { wordId };
    std::vector<DumppedFullEntryInfo> entries;
    exportAllNgramEntriesRelatedToWordInner(headerPolicy, result.mNextLevelBitmapEntryIndex,
            &prevWordIds, &entries);
    return entries;
}

// Depth-first helper for exportAllNgramEntriesRelatedToWord. prevWordIds acts
// as the current context stack: it is pushed/popped around each recursive
// descent so each emitted entry carries its full context.
void LanguageModelDictContent::exportAllNgramEntriesRelatedToWordInner(
        const HeaderPolicy *const headerPolicy, const int bitmapEntryIndex,
        std::vector<int> *const prevWordIds,
        std::vector<DumppedFullEntryInfo> *const outBummpedFullEntryInfo) const {
    for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
        const int wordId = entry.key();
        const ProbabilityEntry probabilityEntry =
                ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
        if (probabilityEntry.isValid()) {
            const WordAttributes wordAttributes = getWordAttributes(
                    WordIdArrayView(*prevWordIds), wordId, true /* mustMatchAllPrevWords */,
                    headerPolicy);
            outBummpedFullEntryInfo->emplace_back(*prevWordIds, wordId,
                    wordAttributes, probabilityEntry);
        }
        if (entry.hasNextLevelMap()) {
            prevWordIds->push_back(wordId);
            exportAllNgramEntriesRelatedToWordInner(headerPolicy,
                    entry.getNextLevelBitmapEntryIndex(), prevWordIds, outBummpedFullEntryInfo);
            prevWordIds->pop_back();
        }
    }
}

// For each n-gram order, evicts the lowest-priority entries until the count
// fits under maxEntryCounts, and records the resulting counts in
// outEntryCounters. Orders already within their limit are left untouched.
bool LanguageModelDictContent::truncateEntries(const EntryCounts &currentEntryCounts,
        const EntryCounts &maxEntryCounts, const HeaderPolicy *const headerPolicy,
        MutableEntryCounters *const outEntryCounters) {
    for (int prevWordCount = 0; prevWordCount <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++prevWordCount) {
        const int totalWordCount = prevWordCount + 1;
        const NgramType ngramType = NgramUtils::getNgramTypeFromWordCount(totalWordCount);
        if (currentEntryCounts.getNgramCount(ngramType)
                <= maxEntryCounts.getNgramCount(ngramType)) {
            // Already within the limit; keep the current count.
            outEntryCounters->setNgramCount(ngramType,
                    currentEntryCounts.getNgramCount(ngramType));
            continue;
        }
        int entryCount = 0;
        if (!turncateEntriesInSpecifiedLevel(headerPolicy,
                maxEntryCounts.getNgramCount(ngramType), prevWordCount, &entryCount)) {
            return false;
        }
        outEntryCounters->setNgramCount(ngramType, entryCount);
    }
    return true;
}

// Records one observation of wordId after prevWordIds: updates the unigram
// entry, the global counters, and every n-gram entry for each non-empty
// prefix of prevWordIds. Newly created n-gram entries are counted in
// entryCountersToUpdate. Only valid for dictionaries with historical info.
bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView prevWordIds,
        const int wordId, const bool isValid, const HistoricalInfo historicalInfo,
        const HeaderPolicy *const headerPolicy, MutableEntryCounters *const entryCountersToUpdate) {
    if (!mHasHistoricalInfo) {
        AKLOGE("updateAllEntriesOnInputWord is called for dictionary without historical info.");
        return false;
    }
    const ProbabilityEntry originalUnigramProbabilityEntry = getProbabilityEntry(wordId);
    const ProbabilityEntry updatedUnigramProbabilityEntry = createUpdatedEntryFrom(
            originalUnigramProbabilityEntry, isValid, historicalInfo, headerPolicy);
    if (!setProbabilityEntry(wordId, &updatedUnigramProbabilityEntry)) {
        return false;
    }
    mGlobalCounters.incrementTotalCount();
    mGlobalCounters.updateMaxValueOfCounters(
            updatedUnigramProbabilityEntry.getHistoricalInfo()->getCount());
    for (size_t i = 0; i < prevWordIds.size(); ++i) {
        if (prevWordIds[i] == NOT_A_WORD_ID) {
            // Context ends here; longer prefixes are unavailable.
            break;
        }
        // TODO: Optimize this code.
        const WordIdArrayView limitedPrevWordIds = prevWordIds.limit(i + 1);
        const ProbabilityEntry originalNgramProbabilityEntry = getNgramProbabilityEntry(
                limitedPrevWordIds, wordId);
        const ProbabilityEntry updatedNgramProbabilityEntry = createUpdatedEntryFrom(
                originalNgramProbabilityEntry, isValid, historicalInfo, headerPolicy);
        if (!setNgramProbabilityEntry(limitedPrevWordIds, wordId, &updatedNgramProbabilityEntry)) {
            return false;
        }
        mGlobalCounters.updateMaxValueOfCounters(
                updatedNgramProbabilityEntry.getHistoricalInfo()->getCount());
        if (!originalNgramProbabilityEntry.isValid()) {
            // (i + 2) words are used in total because the prevWords consists of (i + 1) words when
            // looking at its i-th element.
            entryCountersToUpdate->incrementNgramCount(
                    NgramUtils::getNgramTypeFromWordCount(i + 2));
        }
    }
    return true;
}

// Returns a copy of originalProbabilityEntry whose historical info has the
// new timestamp, level reset to 0, and the observation count accumulated.
// Flags are preserved for a valid original entry, cleared otherwise.
const ProbabilityEntry LanguageModelDictContent::createUpdatedEntryFrom(
        const ProbabilityEntry &originalProbabilityEntry, const bool isValid,
        const HistoricalInfo historicalInfo, const HeaderPolicy *const headerPolicy) const {
    const HistoricalInfo updatedHistoricalInfo = HistoricalInfo(historicalInfo.getTimestamp(),
            0 /* level */, originalProbabilityEntry.getHistoricalInfo()->getCount()
                    + historicalInfo.getCount());
    if (originalProbabilityEntry.isValid()) {
        return ProbabilityEntry(originalProbabilityEntry.getFlags(), &updatedHistoricalInfo);
    } else {
        return ProbabilityEntry(0 /* flags */, &updatedHistoricalInfo);
    }
}

// Recursive GC worker: copies each entry of trieMapRange into this trie map
// at nextLevelBitmapEntryIndex, remapping keys through terminalIdMap and
// skipping entries whose words were removed, then recurses into next levels.
bool LanguageModelDictContent::runGCInner(
        const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
        const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex) {
    for (auto &entry : trieMapRange) {
        const auto it = terminalIdMap->find(entry.key());
        if (it == terminalIdMap->end() || it->second == Ver4DictConstants::NOT_A_TERMINAL_ID) {
            // The word has been removed.
            continue;
        }
        if (!mTrieMap.put(it->second, entry.value(), nextLevelBitmapEntryIndex)) {
            return false;
        }
        if (entry.hasNextLevelMap()) {
            if (!runGCInner(terminalIdMap, entry.getEntriesInNextLevel(),
                    mTrieMap.getNextLevelBitmapEntryIndex(it->second, nextLevelBitmapEntryIndex))) {
                return false;
            }
        }
    }
    return true;
}

// Resolves the trie-map level for prevWordIds, creating missing intermediate
// entries (with a default-encoded ProbabilityEntry) along the way. Returns
// TrieMap::INVALID_INDEX when an intermediate entry cannot be created.
int LanguageModelDictContent::createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds) {
    int lastBitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex();
    for (const int wordId : prevWordIds) {
        const TrieMap::Result result = mTrieMap.get(wordId, lastBitmapEntryIndex);
        if (result.mIsValid && result.mNextLevelBitmapEntryIndex != TrieMap::INVALID_INDEX) {
            // Entry and next level already exist; just descend.
            lastBitmapEntryIndex = result.mNextLevelBitmapEntryIndex;
            continue;
        }
        if (!result.mIsValid) {
            // Create a placeholder entry so a next level can be attached.
            if (!mTrieMap.put(wordId, ProbabilityEntry().encode(mHasHistoricalInfo),
                    lastBitmapEntryIndex)) {
                AKLOGE("Failed to update trie map. wordId: %d, lastBitmapEntryIndex %d", wordId,
                        lastBitmapEntryIndex);
                return TrieMap::INVALID_INDEX;
            }
        }
        lastBitmapEntryIndex = mTrieMap.getNextLevelBitmapEntryIndex(wordId,
                lastBitmapEntryIndex);
    }
    return lastBitmapEntryIndex;
}

// Read-only counterpart of createAndGetBitmapEntryIndex: walks prevWordIds
// down the trie map and returns the reached level's bitmap entry index, or
// TrieMap::INVALID_INDEX when a word on the path is missing.
int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWordIds) const {
    int bitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex();
    for (const int wordId : prevWordIds) {
        const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndex);
        if (!result.mIsValid) {
            return TrieMap::INVALID_INDEX;
        }
        bitmapEntryIndex = result.mNextLevelBitmapEntryIndex;
    }
    return bitmapEntryIndex;
}

// GC pass over one trie-map level (and, recursively, all levels below it):
// removes entries for deleted words, removes decayed entries, optionally
// halves historical counts (removing entries that reach 0), and tallies the
// surviving entries per n-gram type into outEntryCounters.
bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex,
        const int prevWordCount, const HeaderPolicy *const headerPolicy,
        const bool needsToHalveCounters, MutableEntryCounters *const outEntryCounters) {
    for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
        if (prevWordCount > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
            AKLOGE("Invalid prevWordCount. prevWordCount: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.",
                    prevWordCount, MAX_PREV_WORD_COUNT_FOR_N_GRAM);
            return false;
        }
        const ProbabilityEntry probabilityEntry =
                ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
        if (prevWordCount > 0 && probabilityEntry.isValid()
                && !mTrieMap.getRoot(entry.key()).mIsValid) {
            // The entry is related to a word that has been removed. Remove the entry.
            if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
                return false;
            }
            continue;
        }
        if (mHasHistoricalInfo && probabilityEntry.isValid()) {
            const HistoricalInfo *originalHistoricalInfo = probabilityEntry.getHistoricalInfo();
            if (DynamicLanguageModelProbabilityUtils::shouldRemoveEntryDuringGC(
                    *originalHistoricalInfo)) {
                // Remove the entry.
                if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
                    return false;
                }
                continue;
            }
            if (needsToHalveCounters) {
                const int updatedCount = originalHistoricalInfo->getCount() / 2;
                if (updatedCount == 0) {
                    // Remove the entry.
                    if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
                        return false;
                    }
                    continue;
                }
                const HistoricalInfo historicalInfoToSave(originalHistoricalInfo->getTimestamp(),
                        originalHistoricalInfo->getLevel(), updatedCount);
                const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(),
                        &historicalInfoToSave);
                if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo),
                        bitmapEntryIndex)) {
                    return false;
                }
            }
        }
        outEntryCounters->incrementNgramCount(
                NgramUtils::getNgramTypeFromWordCount(prevWordCount + 1));
        if (!entry.hasNextLevelMap()) {
            continue;
        }
        if (!updateAllProbabilityEntriesForGCInner(entry.getNextLevelBitmapEntryIndex(),
                prevWordCount + 1, headerPolicy, needsToHalveCounters, outEntryCounters)) {
            return false;
        }
    }
    return true;
}

// Evicts entries at the n-gram level with targetLevel previous words until at
// most maxEntryCount remain, removing the lowest-priority ones first (via
// partial_sort over EntryInfoToTurncate::Comparator). Writes the surviving
// entry count to outEntryCount. (Name keeps the header's "turncate" spelling.)
bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel(
        const HeaderPolicy *const headerPolicy, const int maxEntryCount, const int targetLevel,
        int *const outEntryCount) {
    std::vector<int> prevWordIds;
    std::vector<EntryInfoToTurncate> entryInfoVector;
    if (!getEntryInfo(headerPolicy, targetLevel, mTrieMap.getRootBitmapEntryIndex(),
            &prevWordIds, &entryInfoVector)) {
        return false;
    }
    if (static_cast<int>(entryInfoVector.size()) <= maxEntryCount) {
        // Nothing to evict.
        *outEntryCount = static_cast<int>(entryInfoVector.size());
        return true;
    }
    *outEntryCount = maxEntryCount;
    const int entryCountToRemove = static_cast<int>(entryInfoVector.size()) - maxEntryCount;
    // Only the entryCountToRemove smallest (lowest-priority) elements need to
    // be ordered; the rest stay unsorted.
    std::partial_sort(entryInfoVector.begin(), entryInfoVector.begin() + entryCountToRemove,
            entryInfoVector.end(),
            EntryInfoToTurncate::Comparator());
    for (int i = 0; i < entryCountToRemove; ++i) {
        const EntryInfoToTurncate &entryInfo = entryInfoVector[i];
        if (!removeNgramProbabilityEntry(
                WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mPrevWordCount),
                entryInfo.mKey)) {
            return false;
        }
    }
    return true;
}

// Recursively collects eviction candidates for every entry at the level with
// targetLevel previous words. The eviction priority is derived from the
// historical info when present, otherwise from the stored probability.
// prevWordIds is the context stack, pushed/popped around each descent.
bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPolicy,
        const int targetLevel, const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
        std::vector<EntryInfoToTurncate> *const outEntryInfo) const {
    const int prevWordCount = prevWordIds->size();
    for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
        if (prevWordCount < targetLevel) {
            // Not deep enough yet; descend where possible.
            if (!entry.hasNextLevelMap()) {
                continue;
            }
            prevWordIds->push_back(entry.key());
            if (!getEntryInfo(headerPolicy, targetLevel, entry.getNextLevelBitmapEntryIndex(),
                    prevWordIds, outEntryInfo)) {
                return false;
            }
            prevWordIds->pop_back();
            continue;
        }
        const ProbabilityEntry probabilityEntry =
                ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
        const int priority = mHasHistoricalInfo
                ? DynamicLanguageModelProbabilityUtils::getPriorityToPreventFromEviction(
                        *probabilityEntry.getHistoricalInfo())
                : probabilityEntry.getProbability();
        outEntryInfo->emplace_back(priority, probabilityEntry.getHistoricalInfo()->getCount(),
                entry.key(), targetLevel, prevWordIds->data());
    }
    return true;
}

// Strict weak ordering for eviction: lower priority first, then lower count,
// then smaller key, then LONGER context first, then lexicographic context.
// Ties on every field compare as equivalent (returns false).
bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()(
        const EntryInfoToTurncate &left, const EntryInfoToTurncate &right) const {
    if (left.mPriority != right.mPriority) {
        return left.mPriority < right.mPriority;
    }
    if (left.mCount != right.mCount) {
        return left.mCount < right.mCount;
    }
    if (left.mKey != right.mKey) {
        return left.mKey < right.mKey;
    }
    if (left.mPrevWordCount != right.mPrevWordCount) {
        // Note the inverted comparison: longer contexts sort earlier.
        return left.mPrevWordCount > right.mPrevWordCount;
    }
    for (int i = 0; i < left.mPrevWordCount; ++i) {
        if (left.mPrevWordIds[i] != right.mPrevWordIds[i]) {
            return left.mPrevWordIds[i] < right.mPrevWordIds[i];
        }
    }
    // left and right represent the same entry.
    return false;
}

// Snapshot of one eviction candidate. Copies the first prevWordCount ids out
// of prevWordIds into the fixed-size member array.
LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int priority,
        const int count, const int key, const int prevWordCount, const int *const prevWordIds)
        : mPriority(priority), mCount(count), mKey(key), mPrevWordCount(prevWordCount) {
    memmove(mPrevWordIds, prevWordIds, mPrevWordCount * sizeof(mPrevWordIds[0]));
}

} // namespace latinime