Home | History | Annotate | Download | only in dicnode
      1 /*
      2  * Copyright (C) 2012 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef LATINIME_DIC_NODE_H
     18 #define LATINIME_DIC_NODE_H
     19 
#include <functional>

#include "defines.h"
#include "suggest/core/dicnode/dic_node_profiler.h"
#include "suggest/core/dicnode/dic_node_utils.h"
#include "suggest/core/dicnode/internal/dic_node_state.h"
#include "suggest/core/dicnode/internal/dic_node_properties.h"
#include "suggest/core/dictionary/digraph_utils.h"
#include "suggest/core/dictionary/error_type_utils.h"
#include "suggest/core/layout/proximity_info_state.h"
#include "utils/char_utils.h"
#include "utils/int_array_view.h"
     30 
// Debug logging helpers. They are only meaningful when expanded inside a DicNode member
// function because they call DicNode accessors unqualified. When DEBUG_DICT is off both
// macros expand to nothing.
#if DEBUG_DICT
// Logs the calling function name, the node code point, inputSize, the total input index,
// the input index of pointer 0, the normalized compound distance and the current word.
// Note: relies on a variable named `inputSize` being in scope at the expansion site.
#define LOGI_SHOW_ADD_COST_PROP \
        do { \
            char charBuf[50]; \
            INTS_TO_CHARS(getOutputWordBuf(), getNodeCodePointCount(), charBuf, NELEMS(charBuf)); \
            AKLOGI("%20s, \"%c\", size = %03d, total = %03d, index(0) = %02d, dist = %.4f, %s,,", \
                    __FUNCTION__, getNodeCodePoint(), inputSize, getTotalInputIndex(), \
                    getInputIndex(0), getNormalizedCompoundDistance(), charBuf); \
        } while (0)
// Logs `header` followed by the spatial distance, language distance, normalized compound
// distance, raw length, the output word (current word plus the previous words' length),
// the input index of pointer 0 and the normalized compound distance after the first word.
#define DUMP_WORD_AND_SCORE(header) \
        do { \
            char charBuf[50]; \
            INTS_TO_CHARS(getOutputWordBuf(), \
                    getNodeCodePointCount() \
                            + mDicNodeState.mDicNodeStateOutput.getPrevWordsLength(), \
                    charBuf, NELEMS(charBuf)); \
            AKLOGI("#%8s, %5f, %5f, %5f, %5f, %s, %d, %5f,", header, \
                    getSpatialDistanceForScoring(), \
                    mDicNodeState.mDicNodeStateScoring.getLanguageDistance(), \
                    getNormalizedCompoundDistance(), getRawLength(), charBuf, \
                    getInputIndex(0), getNormalizedCompoundDistanceAfterFirstWord()); \
        } while (0)
#else
// Non-debug builds: the logging macros compile away entirely.
#define LOGI_SHOW_ADD_COST_PROP
#define DUMP_WORD_AND_SCORE(header)
#endif
     57 
     58 namespace latinime {
     59 
     60 // This struct is purely a bucket to return values. No instances of this struct should be kept.
     61 struct DicNode_InputStateG {
     62     DicNode_InputStateG()
     63             : mNeedsToUpdateInputStateG(false), mPointerId(0), mInputIndex(0),
     64               mPrevCodePoint(0), mTerminalDiffCost(0.0f), mRawLength(0.0f),
     65               mDoubleLetterLevel(NOT_A_DOUBLE_LETTER) {}
     66 
     67     bool mNeedsToUpdateInputStateG;
     68     int mPointerId;
     69     int16_t mInputIndex;
     70     int mPrevCodePoint;
     71     float mTerminalDiffCost;
     72     float mRawLength;
     73     DoubleLetterLevel mDoubleLetterLevel;
     74 };
     75 
     76 class DicNode {
     77     // Caveat: We define Weighting as a friend class of DicNode to let Weighting change
     78     // the distance of DicNode.
     79     // Caution!!! In general, we avoid using the "friend" access modifier.
     80     // This is an exception to explicitly hide DicNode::addCost() from all classes but Weighting.
     81     friend class Weighting;
     82 
     83  public:
     84 #if DEBUG_DICT
     85     DicNodeProfiler mProfiler;
     86 #endif
     87 
     88     AK_FORCE_INLINE DicNode()
     89             :
     90 #if DEBUG_DICT
     91               mProfiler(),
     92 #endif
     93               mDicNodeProperties(), mDicNodeState(), mIsCachedForNextSuggestion(false) {}
     94 
     95     DicNode(const DicNode &dicNode);
     96     DicNode &operator=(const DicNode &dicNode);
     97     ~DicNode() {}
     98 
     99     // Init for copy
    100     void initByCopy(const DicNode *const dicNode) {
    101         mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion;
    102         mDicNodeProperties.initByCopy(&dicNode->mDicNodeProperties);
    103         mDicNodeState.initByCopy(&dicNode->mDicNodeState);
    104         PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
    105     }
    106 
    107     // Init for root with prevWordIds which is used for n-gram
    108     void initAsRoot(const int rootPtNodeArrayPos, const WordIdArrayView prevWordIds) {
    109         mIsCachedForNextSuggestion = false;
    110         mDicNodeProperties.init(rootPtNodeArrayPos, prevWordIds);
    111         mDicNodeState.init();
    112         PROF_NODE_RESET(mProfiler);
    113     }
    114 
    115     // Init for root with previous word
    116     void initAsRootWithPreviousWord(const DicNode *const dicNode, const int rootPtNodeArrayPos) {
    117         mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion;
    118         WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> newPrevWordIds;
    119         newPrevWordIds[0] = dicNode->mDicNodeProperties.getWordId();
    120         dicNode->getPrevWordIds().limit(newPrevWordIds.size() - 1)
    121                 .copyToArray(&newPrevWordIds, 1 /* offset */);
    122         mDicNodeProperties.init(rootPtNodeArrayPos, WordIdArrayView::fromArray(newPrevWordIds));
    123         mDicNodeState.initAsRootWithPreviousWord(&dicNode->mDicNodeState,
    124                 dicNode->mDicNodeProperties.getDepth());
    125         PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
    126     }
    127 
    128     void initAsPassingChild(const DicNode *parentDicNode) {
    129         mIsCachedForNextSuggestion = parentDicNode->mIsCachedForNextSuggestion;
    130         const int codePoint =
    131                 parentDicNode->mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt(
    132                             parentDicNode->getNodeCodePointCount());
    133         mDicNodeProperties.init(&parentDicNode->mDicNodeProperties, codePoint);
    134         mDicNodeState.initByCopy(&parentDicNode->mDicNodeState);
    135         PROF_NODE_COPY(&parentDicNode->mProfiler, mProfiler);
    136     }
    137 
    138     void initAsChild(const DicNode *const dicNode, const int childrenPtNodeArrayPos,
    139             const int wordId, const CodePointArrayView mergedCodePoints) {
    140         uint16_t newDepth = static_cast<uint16_t>(dicNode->getNodeCodePointCount() + 1);
    141         mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion;
    142         const uint16_t newLeavingDepth = static_cast<uint16_t>(
    143                 dicNode->mDicNodeProperties.getLeavingDepth() + mergedCodePoints.size());
    144         mDicNodeProperties.init(childrenPtNodeArrayPos, mergedCodePoints[0],
    145                 wordId, newDepth, newLeavingDepth, dicNode->mDicNodeProperties.getPrevWordIds());
    146         mDicNodeState.init(&dicNode->mDicNodeState, mergedCodePoints.size(),
    147                 mergedCodePoints.data());
    148         PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
    149     }
    150 
    151     bool isRoot() const {
    152         return getNodeCodePointCount() == 0;
    153     }
    154 
    155     bool hasChildren() const {
    156         return mDicNodeProperties.hasChildren();
    157     }
    158 
    159     bool isLeavingNode() const {
    160         ASSERT(getNodeCodePointCount() <= mDicNodeProperties.getLeavingDepth());
    161         return getNodeCodePointCount() == mDicNodeProperties.getLeavingDepth();
    162     }
    163 
    164     AK_FORCE_INLINE bool isFirstLetter() const {
    165         return getNodeCodePointCount() == 1;
    166     }
    167 
    168     bool isCached() const {
    169         return mIsCachedForNextSuggestion;
    170     }
    171 
    172     void setCached() {
    173         mIsCachedForNextSuggestion = true;
    174     }
    175 
    176     // Check if the current word and the previous word can be considered as a valid multiple word
    177     // suggestion.
    178     bool isValidMultipleWordSuggestion() const {
    179         // Treat suggestion as invalid if the current and the previous word are single character
    180         // words.
    181         const int prevWordLen = mDicNodeState.mDicNodeStateOutput.getPrevWordsLength()
    182                 - mDicNodeState.mDicNodeStateOutput.getPrevWordStart() - 1;
    183         const int currentWordLen = getNodeCodePointCount();
    184         return (prevWordLen != 1 || currentWordLen != 1);
    185     }
    186 
    187     bool isFirstCharUppercase() const {
    188         const int c = mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt(0);
    189         return CharUtils::isAsciiUpper(c);
    190     }
    191 
    192     bool isCompletion(const int inputSize) const {
    193         return mDicNodeState.mDicNodeStateInput.getInputIndex(0) >= inputSize;
    194     }
    195 
    196     bool canDoLookAheadCorrection(const int inputSize) const {
    197         return mDicNodeState.mDicNodeStateInput.getInputIndex(0) < inputSize - 1;
    198     }
    199 
    200     // Used to get n-gram probability in DicNodeUtils.
    201     int getWordId() const {
    202         return mDicNodeProperties.getWordId();
    203     }
    204 
    205     const WordIdArrayView getPrevWordIds() const {
    206         return mDicNodeProperties.getPrevWordIds();
    207     }
    208 
    209     // Used in DicNodeUtils
    210     int getChildrenPtNodeArrayPos() const {
    211         return mDicNodeProperties.getChildrenPtNodeArrayPos();
    212     }
    213 
    214     AK_FORCE_INLINE bool isTerminalDicNode() const {
    215         const bool isTerminalPtNode = mDicNodeProperties.isTerminal();
    216         const int currentDicNodeDepth = getNodeCodePointCount();
    217         const int terminalDicNodeDepth = mDicNodeProperties.getLeavingDepth();
    218         return isTerminalPtNode && currentDicNodeDepth > 0
    219                 && currentDicNodeDepth == terminalDicNodeDepth;
    220     }
    221 
    222     bool shouldBeFilteredBySafetyNetForBigram() const {
    223         const uint16_t currentDepth = getNodeCodePointCount();
    224         const int prevWordLen = mDicNodeState.mDicNodeStateOutput.getPrevWordsLength()
    225                 - mDicNodeState.mDicNodeStateOutput.getPrevWordStart() - 1;
    226         return !(currentDepth > 0 && (currentDepth != 1 || prevWordLen != 1));
    227     }
    228 
    229     bool hasMatchedOrProximityCodePoints() const {
    230         // This DicNode does not have matched or proximity code points when all code points have
    231         // been handled as edit corrections or completion so far.
    232         const int editCorrectionCount = mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount();
    233         const int completionCount = mDicNodeState.mDicNodeStateScoring.getCompletionCount();
    234         return (editCorrectionCount + completionCount) < getNodeCodePointCount();
    235     }
    236 
    237     bool isTotalInputSizeExceedingLimit() const {
    238         // TODO: 3 can be 2? Needs to be investigated.
    239         // TODO: Have a const variable for 3 (or 2)
    240         return getTotalNodeCodePointCount() > MAX_WORD_LENGTH - 3;
    241     }
    242 
    243     void outputResult(int *dest) const {
    244         memmove(dest, getOutputWordBuf(), getTotalNodeCodePointCount() * sizeof(dest[0]));
    245         DUMP_WORD_AND_SCORE("OUTPUT");
    246     }
    247 
    248     // "Total" in this context (and other methods in this class) means the whole suggestion. When
    249     // this represents a multi-word suggestion, the referenced PtNode (in mDicNodeState) is only
    250     // the one that corresponds to the last word of the suggestion, and all the previous words
    251     // are concatenated together in mDicNodeStateOutput.
    252     int getTotalNodeSpaceCount() const {
    253         if (!hasMultipleWords()) {
    254             return 0;
    255         }
    256         return CharUtils::getSpaceCount(mDicNodeState.mDicNodeStateOutput.getCodePointBuf(),
    257                 mDicNodeState.mDicNodeStateOutput.getPrevWordsLength());
    258     }
    259 
    260     int getSecondWordFirstInputIndex(const ProximityInfoState *const pInfoState) const {
    261         const int inputIndex = mDicNodeState.mDicNodeStateOutput.getSecondWordFirstInputIndex();
    262         if (inputIndex == NOT_AN_INDEX) {
    263             return NOT_AN_INDEX;
    264         } else {
    265             return pInfoState->getInputIndexOfSampledPoint(inputIndex);
    266         }
    267     }
    268 
    269     bool hasMultipleWords() const {
    270         return mDicNodeState.mDicNodeStateOutput.getPrevWordCount() > 0;
    271     }
    272 
    273     int getProximityCorrectionCount() const {
    274         return mDicNodeState.mDicNodeStateScoring.getProximityCorrectionCount();
    275     }
    276 
    277     int getEditCorrectionCount() const {
    278         return mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount();
    279     }
    280 
    281     // Used to prune nodes
    282     float getNormalizedCompoundDistance() const {
    283         return mDicNodeState.mDicNodeStateScoring.getNormalizedCompoundDistance();
    284     }
    285 
    286     // Used to prune nodes
    287     float getNormalizedSpatialDistance() const {
    288         return mDicNodeState.mDicNodeStateScoring.getSpatialDistance()
    289                 / static_cast<float>(getInputIndex(0) + 1);
    290     }
    291 
    292     // Used to prune nodes
    293     float getCompoundDistance() const {
    294         return mDicNodeState.mDicNodeStateScoring.getCompoundDistance();
    295     }
    296 
    297     // Used to prune nodes
    298     float getCompoundDistance(const float weightOfLangModelVsSpatialModel) const {
    299         return mDicNodeState.mDicNodeStateScoring.getCompoundDistance(
    300                 weightOfLangModelVsSpatialModel);
    301     }
    302 
    303     AK_FORCE_INLINE const int *getOutputWordBuf() const {
    304         return mDicNodeState.mDicNodeStateOutput.getCodePointBuf();
    305     }
    306 
    307     int getPrevCodePointG(int pointerId) const {
    308         return mDicNodeState.mDicNodeStateInput.getPrevCodePoint(pointerId);
    309     }
    310 
    311     // Whether the current codepoint can be an intentional omission, in which case the traversal
    312     // algorithm will always check for a possible omission here.
    313     bool canBeIntentionalOmission() const {
    314         return CharUtils::isIntentionalOmissionCodePoint(getNodeCodePoint());
    315     }
    316 
    317     // Whether the omission is so frequent that it should incur zero cost.
    318     bool isZeroCostOmission() const {
    319         // TODO: do not hardcode and read from header
    320         return (getNodeCodePoint() == KEYCODE_SINGLE_QUOTE);
    321     }
    322 
    323     // TODO: remove
    324     float getTerminalDiffCostG(int path) const {
    325         return mDicNodeState.mDicNodeStateInput.getTerminalDiffCost(path);
    326     }
    327 
    328     //////////////////////
    329     // Temporary getter //
    330     // TODO: Remove     //
    331     //////////////////////
    332     // TODO: Remove once touch path is merged into ProximityInfoState
    333     // Note: Returned codepoint may be a digraph codepoint if the node is in a composite glyph.
    334     int getNodeCodePoint() const {
    335         const int codePoint = mDicNodeProperties.getDicNodeCodePoint();
    336         const DigraphUtils::DigraphCodePointIndex digraphIndex =
    337                 mDicNodeState.mDicNodeStateScoring.getDigraphIndex();
    338         if (digraphIndex == DigraphUtils::NOT_A_DIGRAPH_INDEX) {
    339             return codePoint;
    340         }
    341         return DigraphUtils::getDigraphCodePointForIndex(codePoint, digraphIndex);
    342     }
    343 
    344     ////////////////////////////////
    345     // Utils for cost calculation //
    346     ////////////////////////////////
    347     AK_FORCE_INLINE bool isSameNodeCodePoint(const DicNode *const dicNode) const {
    348         return mDicNodeProperties.getDicNodeCodePoint()
    349                 == dicNode->mDicNodeProperties.getDicNodeCodePoint();
    350     }
    351 
    352     // TODO: remove
    353     // TODO: rename getNextInputIndex
    354     int16_t getInputIndex(int pointerId) const {
    355         return mDicNodeState.mDicNodeStateInput.getInputIndex(pointerId);
    356     }
    357 
    358     ////////////////////////////////////
    359     // Getter of features for scoring //
    360     ////////////////////////////////////
    361     float getSpatialDistanceForScoring() const {
    362         return mDicNodeState.mDicNodeStateScoring.getSpatialDistance();
    363     }
    364 
    365     // For space-aware gestures, we store the normalized distance at the char index
    366     // that ends the first word of the suggestion. We call this the distance after
    367     // first word.
    368     float getNormalizedCompoundDistanceAfterFirstWord() const {
    369         return mDicNodeState.mDicNodeStateScoring.getNormalizedCompoundDistanceAfterFirstWord();
    370     }
    371 
    372     float getRawLength() const {
    373         return mDicNodeState.mDicNodeStateScoring.getRawLength();
    374     }
    375 
    376     DoubleLetterLevel getDoubleLetterLevel() const {
    377         return mDicNodeState.mDicNodeStateScoring.getDoubleLetterLevel();
    378     }
    379 
    380     void setDoubleLetterLevel(DoubleLetterLevel doubleLetterLevel) {
    381         mDicNodeState.mDicNodeStateScoring.setDoubleLetterLevel(doubleLetterLevel);
    382     }
    383 
    384     bool isInDigraph() const {
    385         return mDicNodeState.mDicNodeStateScoring.getDigraphIndex()
    386                 != DigraphUtils::NOT_A_DIGRAPH_INDEX;
    387     }
    388 
    389     void advanceDigraphIndex() {
    390         mDicNodeState.mDicNodeStateScoring.advanceDigraphIndex();
    391     }
    392 
    393     ErrorTypeUtils::ErrorType getContainedErrorTypes() const {
    394         return mDicNodeState.mDicNodeStateScoring.getContainedErrorTypes();
    395     }
    396 
    397     inline uint16_t getNodeCodePointCount() const {
    398         return mDicNodeProperties.getDepth();
    399     }
    400 
    401     // Returns code point count including spaces
    402     inline uint16_t getTotalNodeCodePointCount() const {
    403         return getNodeCodePointCount() + mDicNodeState.mDicNodeStateOutput.getPrevWordsLength();
    404     }
    405 
    406     AK_FORCE_INLINE void dump(const char *tag) const {
    407 #if DEBUG_DICT
    408         DUMP_WORD_AND_SCORE(tag);
    409 #if DEBUG_DUMP_ERROR
    410         mProfiler.dump();
    411 #endif
    412 #endif
    413     }
    414 
    415     AK_FORCE_INLINE bool compare(const DicNode *right) const {
    416         // Promote exact matches to prevent them from being pruned.
    417         const bool leftExactMatch = ErrorTypeUtils::isExactMatch(getContainedErrorTypes());
    418         const bool rightExactMatch = ErrorTypeUtils::isExactMatch(right->getContainedErrorTypes());
    419         if (leftExactMatch != rightExactMatch) {
    420             return leftExactMatch;
    421         }
    422         const float diff =
    423                 right->getNormalizedCompoundDistance() - getNormalizedCompoundDistance();
    424         static const float MIN_DIFF = 0.000001f;
    425         if (diff > MIN_DIFF) {
    426             return true;
    427         } else if (diff < -MIN_DIFF) {
    428             return false;
    429         }
    430         const int depth = getNodeCodePointCount();
    431         const int depthDiff = right->getNodeCodePointCount() - depth;
    432         if (depthDiff != 0) {
    433             return depthDiff > 0;
    434         }
    435         for (int i = 0; i < depth; ++i) {
    436             const int codePoint = mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt(i);
    437             const int rightCodePoint =
    438                     right->mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt(i);
    439             if (codePoint != rightCodePoint) {
    440                 return rightCodePoint > codePoint;
    441             }
    442         }
    443         // Compare pointer values here for stable comparison
    444         return this > right;
    445     }
    446 
    447  private:
    448     DicNodeProperties mDicNodeProperties;
    449     DicNodeState mDicNodeState;
    450     // TODO: Remove
    451     bool mIsCachedForNextSuggestion;
    452 
    453     AK_FORCE_INLINE int getTotalInputIndex() const {
    454         int index = 0;
    455         for (int i = 0; i < MAX_POINTER_COUNT_G; i++) {
    456             index += mDicNodeState.mDicNodeStateInput.getInputIndex(i);
    457         }
    458         return index;
    459     }
    460 
    461     // Caveat: Must not be called outside Weighting
    462     // This restriction is guaranteed by "friend"
    463     AK_FORCE_INLINE void addCost(const float spatialCost, const float languageCost,
    464             const bool doNormalization, const int inputSize,
    465             const ErrorTypeUtils::ErrorType errorType) {
    466         if (DEBUG_GEO_FULL) {
    467             LOGI_SHOW_ADD_COST_PROP;
    468         }
    469         mDicNodeState.mDicNodeStateScoring.addCost(spatialCost, languageCost, doNormalization,
    470                 inputSize, getTotalInputIndex(), errorType);
    471     }
    472 
    473     // Saves the current normalized compound distance for space-aware gestures.
    474     // See getNormalizedCompoundDistanceAfterFirstWord for details.
    475     AK_FORCE_INLINE void saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet() {
    476         mDicNodeState.mDicNodeStateScoring.saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet();
    477     }
    478 
    479     // Caveat: Must not be called outside Weighting
    480     // This restriction is guaranteed by "friend"
    481     AK_FORCE_INLINE void forwardInputIndex(const int pointerId, const int count,
    482             const bool overwritesPrevCodePointByNodeCodePoint) {
    483         if (count == 0) {
    484             return;
    485         }
    486         mDicNodeState.mDicNodeStateInput.forwardInputIndex(pointerId, count);
    487         if (overwritesPrevCodePointByNodeCodePoint) {
    488             mDicNodeState.mDicNodeStateInput.setPrevCodePoint(0, getNodeCodePoint());
    489         }
    490     }
    491 
    492     AK_FORCE_INLINE void updateInputIndexG(const DicNode_InputStateG *const inputStateG) {
    493         if (mDicNodeState.mDicNodeStateOutput.getPrevWordCount() == 1 && isFirstLetter()) {
    494             mDicNodeState.mDicNodeStateOutput.setSecondWordFirstInputIndex(
    495                     inputStateG->mInputIndex);
    496         }
    497         mDicNodeState.mDicNodeStateInput.updateInputIndexG(inputStateG->mPointerId,
    498                 inputStateG->mInputIndex, inputStateG->mPrevCodePoint,
    499                 inputStateG->mTerminalDiffCost, inputStateG->mRawLength);
    500         mDicNodeState.mDicNodeStateScoring.addRawLength(inputStateG->mRawLength);
    501         mDicNodeState.mDicNodeStateScoring.setDoubleLetterLevel(inputStateG->mDoubleLetterLevel);
    502     }
    503 };
    504 } // namespace latinime
    505 #endif // LATINIME_DIC_NODE_H
    506