Home | History | Annotate | Download | only in content
      1 /*
      2  * Copyright (C) 2014, The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *     http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H
     18 #define LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H
     19 
     20 #include <cstdio>
     21 #include <vector>
     22 
     23 #include "defines.h"
     24 #include "dictionary/property/word_attributes.h"
     25 #include "dictionary/structure/v4/content/language_model_dict_content_global_counters.h"
     26 #include "dictionary/structure/v4/content/probability_entry.h"
     27 #include "dictionary/structure/v4/content/terminal_position_lookup_table.h"
     28 #include "dictionary/structure/v4/ver4_dict_constants.h"
     29 #include "dictionary/utils/entry_counters.h"
     30 #include "dictionary/utils/trie_map.h"
     31 #include "utils/byte_array_view.h"
     32 #include "utils/int_array_view.h"
     33 
     34 namespace latinime {
     35 
     36 class HeaderPolicy;
     37 
     38 /**
 * Class representing the language model.
 *
 * This class provides methods to get and store unigram/n-gram probability information and flags.
     42  */
     43 class LanguageModelDictContent {
     44  public:
     45     // Pair of word id and probability entry used for iteration.
     46     class WordIdAndProbabilityEntry {
     47      public:
     48         WordIdAndProbabilityEntry(const int wordId, const ProbabilityEntry &probabilityEntry)
     49                 : mWordId(wordId), mProbabilityEntry(probabilityEntry) {}
     50 
     51         int getWordId() const { return mWordId; }
     52         const ProbabilityEntry getProbabilityEntry() const { return mProbabilityEntry; }
     53 
     54      private:
     55         DISALLOW_DEFAULT_CONSTRUCTOR(WordIdAndProbabilityEntry);
     56         DISALLOW_ASSIGNMENT_OPERATOR(WordIdAndProbabilityEntry);
     57 
     58         const int mWordId;
     59         const ProbabilityEntry mProbabilityEntry;
     60     };
     61 
     62     // Iterator.
     63     class EntryIterator {
     64      public:
     65         EntryIterator(const TrieMap::TrieMapIterator &trieMapIterator,
     66                 const bool hasHistoricalInfo)
     67                 : mTrieMapIterator(trieMapIterator), mHasHistoricalInfo(hasHistoricalInfo) {}
     68 
     69         const WordIdAndProbabilityEntry operator*() const {
     70             const TrieMap::TrieMapIterator::IterationResult &result = *mTrieMapIterator;
     71             return WordIdAndProbabilityEntry(
     72                     result.key(), ProbabilityEntry::decode(result.value(), mHasHistoricalInfo));
     73         }
     74 
     75         bool operator!=(const EntryIterator &other) const {
     76             return mTrieMapIterator != other.mTrieMapIterator;
     77         }
     78 
     79         const EntryIterator &operator++() {
     80             ++mTrieMapIterator;
     81             return *this;
     82         }
     83 
     84      private:
     85         DISALLOW_DEFAULT_CONSTRUCTOR(EntryIterator);
     86         DISALLOW_ASSIGNMENT_OPERATOR(EntryIterator);
     87 
     88         TrieMap::TrieMapIterator mTrieMapIterator;
     89         const bool mHasHistoricalInfo;
     90     };
     91 
     92     // Class represents range to use range base for loops.
     93     class EntryRange {
     94      public:
     95         EntryRange(const TrieMap::TrieMapRange trieMapRange, const bool hasHistoricalInfo)
     96                 : mTrieMapRange(trieMapRange), mHasHistoricalInfo(hasHistoricalInfo) {}
     97 
     98         EntryIterator begin() const {
     99             return EntryIterator(mTrieMapRange.begin(), mHasHistoricalInfo);
    100         }
    101 
    102         EntryIterator end() const {
    103             return EntryIterator(mTrieMapRange.end(), mHasHistoricalInfo);
    104         }
    105 
    106      private:
    107         DISALLOW_DEFAULT_CONSTRUCTOR(EntryRange);
    108         DISALLOW_ASSIGNMENT_OPERATOR(EntryRange);
    109 
    110         const TrieMap::TrieMapRange mTrieMapRange;
    111         const bool mHasHistoricalInfo;
    112     };
    113 
    114     class DumppedFullEntryInfo {
    115      public:
    116         DumppedFullEntryInfo(std::vector<int> &prevWordIds, const int targetWordId,
    117                 const WordAttributes &wordAttributes, const ProbabilityEntry &probabilityEntry)
    118                 : mPrevWordIds(prevWordIds), mTargetWordId(targetWordId),
    119                   mWordAttributes(wordAttributes), mProbabilityEntry(probabilityEntry) {}
    120 
    121         const WordIdArrayView getPrevWordIds() const { return WordIdArrayView(mPrevWordIds); }
    122         int getTargetWordId() const { return mTargetWordId; }
    123         const WordAttributes &getWordAttributes() const { return mWordAttributes; }
    124         const ProbabilityEntry &getProbabilityEntry() const { return mProbabilityEntry; }
    125 
    126      private:
    127         DISALLOW_ASSIGNMENT_OPERATOR(DumppedFullEntryInfo);
    128 
    129         const std::vector<int> mPrevWordIds;
    130         const int mTargetWordId;
    131         const WordAttributes mWordAttributes;
    132         const ProbabilityEntry mProbabilityEntry;
    133     };
    134 
    135     LanguageModelDictContent(const ReadWriteByteArrayView *const buffers,
    136             const bool hasHistoricalInfo)
    137             : mTrieMap(buffers[TRIE_MAP_BUFFER_INDEX]),
    138               mGlobalCounters(buffers[GLOBAL_COUNTERS_BUFFER_INDEX]),
    139               mHasHistoricalInfo(hasHistoricalInfo) {}
    140 
    141     explicit LanguageModelDictContent(const bool hasHistoricalInfo)
    142             : mTrieMap(), mGlobalCounters(), mHasHistoricalInfo(hasHistoricalInfo) {}
    143 
    144     bool isNearSizeLimit() const {
    145         return mTrieMap.isNearSizeLimit() || mGlobalCounters.needsToHalveCounters();
    146     }
    147 
    148     bool save(FILE *const file) const;
    149 
    150     bool runGC(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
    151             const LanguageModelDictContent *const originalContent);
    152 
    153     const WordAttributes getWordAttributes(const WordIdArrayView prevWordIds, const int wordId,
    154             const bool mustMatchAllPrevWords, const HeaderPolicy *const headerPolicy) const;
    155 
    156     ProbabilityEntry getProbabilityEntry(const int wordId) const {
    157         return getNgramProbabilityEntry(WordIdArrayView(), wordId);
    158     }
    159 
    160     bool setProbabilityEntry(const int wordId, const ProbabilityEntry *const probabilityEntry) {
    161         mGlobalCounters.addToTotalCount(probabilityEntry->getHistoricalInfo()->getCount());
    162         return setNgramProbabilityEntry(WordIdArrayView(), wordId, probabilityEntry);
    163     }
    164 
    165     bool removeProbabilityEntry(const int wordId) {
    166         return removeNgramProbabilityEntry(WordIdArrayView(), wordId);
    167     }
    168 
    169     ProbabilityEntry getNgramProbabilityEntry(const WordIdArrayView prevWordIds,
    170             const int wordId) const;
    171 
    172     bool setNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId,
    173             const ProbabilityEntry *const probabilityEntry);
    174 
    175     bool removeNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId);
    176 
    177     EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const;
    178 
    179     std::vector<DumppedFullEntryInfo> exportAllNgramEntriesRelatedToWord(
    180             const HeaderPolicy *const headerPolicy, const int wordId) const;
    181 
    182     bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy,
    183             MutableEntryCounters *const outEntryCounters) {
    184         if (!updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(),
    185                 0 /* prevWordCount */, headerPolicy, mGlobalCounters.needsToHalveCounters(),
    186                 outEntryCounters)) {
    187             return false;
    188         }
    189         if (mGlobalCounters.needsToHalveCounters()) {
    190             mGlobalCounters.halveCounters();
    191         }
    192         return true;
    193     }
    194 
    195     // entryCounts should be created by updateAllProbabilityEntries.
    196     bool truncateEntries(const EntryCounts &currentEntryCounts, const EntryCounts &maxEntryCounts,
    197             const HeaderPolicy *const headerPolicy, MutableEntryCounters *const outEntryCounters);
    198 
    199     bool updateAllEntriesOnInputWord(const WordIdArrayView prevWordIds, const int wordId,
    200             const bool isValid, const HistoricalInfo historicalInfo,
    201             const HeaderPolicy *const headerPolicy,
    202             MutableEntryCounters *const entryCountersToUpdate);
    203 
    204  private:
    205     DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent);
    206 
    207     class EntryInfoToTurncate {
    208      public:
    209         class Comparator {
    210          public:
    211             bool operator()(const EntryInfoToTurncate &left,
    212                     const EntryInfoToTurncate &right) const;
    213          private:
    214             DISALLOW_ASSIGNMENT_OPERATOR(Comparator);
    215         };
    216 
    217         EntryInfoToTurncate(const int priority, const int count, const int key,
    218                 const int prevWordCount, const int *const prevWordIds);
    219 
    220         int mPriority;
    221         // TODO: Remove.
    222         int mCount;
    223         int mKey;
    224         int mPrevWordCount;
    225         int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
    226 
    227      private:
    228         DISALLOW_DEFAULT_CONSTRUCTOR(EntryInfoToTurncate);
    229     };
    230 
    231     static const int TRIE_MAP_BUFFER_INDEX;
    232     static const int GLOBAL_COUNTERS_BUFFER_INDEX;
    233 
    234     TrieMap mTrieMap;
    235     LanguageModelDictContentGlobalCounters mGlobalCounters;
    236     const bool mHasHistoricalInfo;
    237 
    238     bool runGCInner(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
    239             const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex);
    240     int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
    241     int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
    242     bool updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int prevWordCount,
    243             const HeaderPolicy *const headerPolicy, const bool needsToHalveCounters,
    244             MutableEntryCounters *const outEntryCounters);
    245     bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy,
    246             const int maxEntryCount, const int targetLevel, int *const outEntryCount);
    247     bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel,
    248             const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
    249             std::vector<EntryInfoToTurncate> *const outEntryInfo) const;
    250     const ProbabilityEntry createUpdatedEntryFrom(const ProbabilityEntry &originalProbabilityEntry,
    251             const bool isValid, const HistoricalInfo historicalInfo,
    252             const HeaderPolicy *const headerPolicy) const;
    253     void exportAllNgramEntriesRelatedToWordInner(const HeaderPolicy *const headerPolicy,
    254             const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
    255             std::vector<DumppedFullEntryInfo> *const outBummpedFullEntryInfo) const;
    256 };
    257 } // namespace latinime
    258 #endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */
    259