Home | History | Annotate | Download | only in src
      1 /*
      2  * Copyright (C) 2011 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include <assert.h>
     18 #include <ctype.h>
     19 #include <stdio.h>
     20 #include <string.h>
     21 
     22 #define LOG_TAG "LatinIME: correction.cpp"
     23 
     24 #include "correction.h"
     25 #include "dictionary.h"
     26 #include "proximity_info.h"
     27 
     28 namespace latinime {
     29 
     30 /////////////////////////////
      31 // edit distance functions //
     32 /////////////////////////////
     33 
     34 #if 0 /* no longer used */
     35 inline static int editDistance(
     36         int* editDistanceTable, const unsigned short* input,
     37         const int inputLength, const unsigned short* output, const int outputLength) {
     38     // dp[li][lo] dp[a][b] = dp[ a * lo + b]
     39     int* dp = editDistanceTable;
     40     const int li = inputLength + 1;
     41     const int lo = outputLength + 1;
     42     for (int i = 0; i < li; ++i) {
     43         dp[lo * i] = i;
     44     }
     45     for (int i = 0; i < lo; ++i) {
     46         dp[i] = i;
     47     }
     48 
     49     for (int i = 0; i < li - 1; ++i) {
     50         for (int j = 0; j < lo - 1; ++j) {
     51             const uint32_t ci = Dictionary::toBaseLowerCase(input[i]);
     52             const uint32_t co = Dictionary::toBaseLowerCase(output[j]);
     53             const uint16_t cost = (ci == co) ? 0 : 1;
     54             dp[(i + 1) * lo + (j + 1)] = min(dp[i * lo + (j + 1)] + 1,
     55                     min(dp[(i + 1) * lo + j] + 1, dp[i * lo + j] + cost));
     56             if (i > 0 && j > 0 && ci == Dictionary::toBaseLowerCase(output[j - 1])
     57                     && co == Dictionary::toBaseLowerCase(input[i - 1])) {
     58                 dp[(i + 1) * lo + (j + 1)] = min(
     59                         dp[(i + 1) * lo + (j + 1)], dp[(i - 1) * lo + (j - 1)] + cost);
     60             }
     61         }
     62     }
     63 
     64     if (DEBUG_EDIT_DISTANCE) {
     65         LOGI("IN = %d, OUT = %d", inputLength, outputLength);
     66         for (int i = 0; i < li; ++i) {
     67             for (int j = 0; j < lo; ++j) {
     68                 LOGI("EDIT[%d][%d], %d", i, j, dp[i * lo + j]);
     69             }
     70         }
     71     }
     72     return dp[li * lo - 1];
     73 }
     74 #endif
     75 
     76 inline static void initEditDistance(int *editDistanceTable) {
     77     for (int i = 0; i <= MAX_WORD_LENGTH_INTERNAL; ++i) {
     78         editDistanceTable[i] = i;
     79     }
     80 }
     81 
     82 inline static void calcEditDistanceOneStep(int *editDistanceTable, const unsigned short *input,
     83         const int inputLength, const unsigned short *output, const int outputLength) {
     84     // Let dp[i][j] be editDistanceTable[i * (inputLength + 1) + j].
     85     // Assuming that dp[0][0] ... dp[outputLength - 1][inputLength] are already calculated,
     86     // and calculate dp[ouputLength][0] ... dp[outputLength][inputLength].
     87     int *const current = editDistanceTable + outputLength * (inputLength + 1);
     88     const int *const prev = editDistanceTable + (outputLength - 1) * (inputLength + 1);
     89     const int *const prevprev =
     90             outputLength >= 2 ? editDistanceTable + (outputLength - 2) * (inputLength + 1) : NULL;
     91     current[0] = outputLength;
     92     const uint32_t co = Dictionary::toBaseLowerCase(output[outputLength - 1]);
     93     const uint32_t prevCO =
     94             outputLength >= 2 ? Dictionary::toBaseLowerCase(output[outputLength - 2]) : 0;
     95     for (int i = 1; i <= inputLength; ++i) {
     96         const uint32_t ci = Dictionary::toBaseLowerCase(input[i - 1]);
     97         const uint16_t cost = (ci == co) ? 0 : 1;
     98         current[i] = min(current[i - 1] + 1, min(prev[i] + 1, prev[i - 1] + cost));
     99         if (i >= 2 && prevprev && ci == prevCO
    100                 && co == Dictionary::toBaseLowerCase(input[i - 2])) {
    101             current[i] = min(current[i], prevprev[i - 2] + 1);
    102         }
    103     }
    104 }
    105 
    106 inline static int getCurrentEditDistance(
    107         int *editDistanceTable, const int inputLength, const int outputLength) {
    108     return editDistanceTable[(inputLength + 1) * (outputLength + 1) - 1];
    109 }
    110 
    111 //////////////////////
    112 // inline functions //
    113 //////////////////////
    114 static const char QUOTE = '\'';
    115 
    116 inline bool Correction::isQuote(const unsigned short c) {
    117     const unsigned short userTypedChar = mProximityInfo->getPrimaryCharAt(mInputIndex);
    118     return (c == QUOTE && userTypedChar != QUOTE);
    119 }
    120 
    121 ////////////////
    122 // Correction //
    123 ////////////////
    124 
    125 Correction::Correction(const int typedLetterMultiplier, const int fullWordMultiplier)
    126         : TYPED_LETTER_MULTIPLIER(typedLetterMultiplier), FULL_WORD_MULTIPLIER(fullWordMultiplier) {
    127     initEditDistance(mEditDistanceTable);
    128 }
    129 
    130 void Correction::initCorrection(const ProximityInfo *pi, const int inputLength,
    131         const int maxDepth) {
    132     mProximityInfo = pi;
    133     mInputLength = inputLength;
    134     mMaxDepth = maxDepth;
    135     mMaxEditDistance = mInputLength < 5 ? 2 : mInputLength / 2;
    136 }
    137 
    138 void Correction::initCorrectionState(
    139         const int rootPos, const int childCount, const bool traverseAll) {
    140     latinime::initCorrectionState(mCorrectionStates, rootPos, childCount, traverseAll);
    141     // TODO: remove
    142     mCorrectionStates[0].mTransposedPos = mTransposedPos;
    143     mCorrectionStates[0].mExcessivePos = mExcessivePos;
    144     mCorrectionStates[0].mSkipPos = mSkipPos;
    145 }
    146 
    147 void Correction::setCorrectionParams(const int skipPos, const int excessivePos,
    148         const int transposedPos, const int spaceProximityPos, const int missingSpacePos,
    149         const bool useFullEditDistance) {
    150     // TODO: remove
    151     mTransposedPos = transposedPos;
    152     mExcessivePos = excessivePos;
    153     mSkipPos = skipPos;
    154     // TODO: remove
    155     mCorrectionStates[0].mTransposedPos = transposedPos;
    156     mCorrectionStates[0].mExcessivePos = excessivePos;
    157     mCorrectionStates[0].mSkipPos = skipPos;
    158 
    159     mSpaceProximityPos = spaceProximityPos;
    160     mMissingSpacePos = missingSpacePos;
    161     mUseFullEditDistance = useFullEditDistance;
    162 }
    163 
    164 void Correction::checkState() {
    165     if (DEBUG_DICT) {
    166         int inputCount = 0;
    167         if (mSkipPos >= 0) ++inputCount;
    168         if (mExcessivePos >= 0) ++inputCount;
    169         if (mTransposedPos >= 0) ++inputCount;
    170         // TODO: remove this assert
    171         assert(inputCount <= 1);
    172     }
    173 }
    174 
    175 int Correction::getFreqForSplitTwoWords(const int firstFreq, const int secondFreq,
    176         const unsigned short *word) {
    177     return Correction::RankingAlgorithm::calcFreqForSplitTwoWords(
    178             firstFreq, secondFreq, this, word);
    179 }
    180 
    181 int Correction::getFinalFreq(const int freq, unsigned short **word, int *wordLength) {
    182     const int outputIndex = mTerminalOutputIndex;
    183     const int inputIndex = mTerminalInputIndex;
    184     *wordLength = outputIndex + 1;
    185     if (mProximityInfo->sameAsTyped(mWord, outputIndex + 1) || outputIndex < MIN_SUGGEST_DEPTH) {
    186         return -1;
    187     }
    188 
    189     *word = mWord;
    190     return Correction::RankingAlgorithm::calculateFinalFreq(
    191             inputIndex, outputIndex, freq, mEditDistanceTable, this);
    192 }
    193 
    194 bool Correction::initProcessState(const int outputIndex) {
    195     if (mCorrectionStates[outputIndex].mChildCount <= 0) {
    196         return false;
    197     }
    198     mOutputIndex = outputIndex;
    199     --(mCorrectionStates[outputIndex].mChildCount);
    200     mInputIndex = mCorrectionStates[outputIndex].mInputIndex;
    201     mNeedsToTraverseAllNodes = mCorrectionStates[outputIndex].mNeedsToTraverseAllNodes;
    202 
    203     mEquivalentCharCount = mCorrectionStates[outputIndex].mEquivalentCharCount;
    204     mProximityCount = mCorrectionStates[outputIndex].mProximityCount;
    205     mTransposedCount = mCorrectionStates[outputIndex].mTransposedCount;
    206     mExcessiveCount = mCorrectionStates[outputIndex].mExcessiveCount;
    207     mSkippedCount = mCorrectionStates[outputIndex].mSkippedCount;
    208     mLastCharExceeded = mCorrectionStates[outputIndex].mLastCharExceeded;
    209 
    210     mTransposedPos = mCorrectionStates[outputIndex].mTransposedPos;
    211     mExcessivePos = mCorrectionStates[outputIndex].mExcessivePos;
    212     mSkipPos = mCorrectionStates[outputIndex].mSkipPos;
    213 
    214     mMatching = false;
    215     mProximityMatching = false;
    216     mTransposing = false;
    217     mExceeding = false;
    218     mSkipping = false;
    219 
    220     return true;
    221 }
    222 
    223 int Correction::goDownTree(
    224         const int parentIndex, const int childCount, const int firstChildPos) {
    225     mCorrectionStates[mOutputIndex].mParentIndex = parentIndex;
    226     mCorrectionStates[mOutputIndex].mChildCount = childCount;
    227     mCorrectionStates[mOutputIndex].mSiblingPos = firstChildPos;
    228     return mOutputIndex;
    229 }
    230 
    231 // TODO: remove
    232 int Correction::getOutputIndex() {
    233     return mOutputIndex;
    234 }
    235 
    236 // TODO: remove
    237 int Correction::getInputIndex() {
    238     return mInputIndex;
    239 }
    240 
    241 // TODO: remove
    242 bool Correction::needsToTraverseAllNodes() {
    243     return mNeedsToTraverseAllNodes;
    244 }
    245 
    246 void Correction::incrementInputIndex() {
    247     ++mInputIndex;
    248 }
    249 
    250 void Correction::incrementOutputIndex() {
    251     ++mOutputIndex;
    252     mCorrectionStates[mOutputIndex].mParentIndex = mCorrectionStates[mOutputIndex - 1].mParentIndex;
    253     mCorrectionStates[mOutputIndex].mChildCount = mCorrectionStates[mOutputIndex - 1].mChildCount;
    254     mCorrectionStates[mOutputIndex].mSiblingPos = mCorrectionStates[mOutputIndex - 1].mSiblingPos;
    255     mCorrectionStates[mOutputIndex].mInputIndex = mInputIndex;
    256     mCorrectionStates[mOutputIndex].mNeedsToTraverseAllNodes = mNeedsToTraverseAllNodes;
    257 
    258     mCorrectionStates[mOutputIndex].mEquivalentCharCount = mEquivalentCharCount;
    259     mCorrectionStates[mOutputIndex].mProximityCount = mProximityCount;
    260     mCorrectionStates[mOutputIndex].mTransposedCount = mTransposedCount;
    261     mCorrectionStates[mOutputIndex].mExcessiveCount = mExcessiveCount;
    262     mCorrectionStates[mOutputIndex].mSkippedCount = mSkippedCount;
    263 
    264     mCorrectionStates[mOutputIndex].mSkipPos = mSkipPos;
    265     mCorrectionStates[mOutputIndex].mTransposedPos = mTransposedPos;
    266     mCorrectionStates[mOutputIndex].mExcessivePos = mExcessivePos;
    267 
    268     mCorrectionStates[mOutputIndex].mLastCharExceeded = mLastCharExceeded;
    269 
    270     mCorrectionStates[mOutputIndex].mMatching = mMatching;
    271     mCorrectionStates[mOutputIndex].mProximityMatching = mProximityMatching;
    272     mCorrectionStates[mOutputIndex].mTransposing = mTransposing;
    273     mCorrectionStates[mOutputIndex].mExceeding = mExceeding;
    274     mCorrectionStates[mOutputIndex].mSkipping = mSkipping;
    275 }
    276 
    277 void Correction::startToTraverseAllNodes() {
    278     mNeedsToTraverseAllNodes = true;
    279 }
    280 
    281 bool Correction::needsToPrune() const {
    282     // TODO: use edit distance here
    283     return mOutputIndex - 1 >= mMaxDepth || mProximityCount > mMaxEditDistance;
    284 }
    285 
    286 void Correction::addCharToCurrentWord(const int32_t c) {
    287     mWord[mOutputIndex] = c;
    288     const unsigned short *primaryInputWord = mProximityInfo->getPrimaryInputWord();
    289     calcEditDistanceOneStep(mEditDistanceTable, primaryInputWord, mInputLength,
    290             mWord, mOutputIndex + 1);
    291 }
    292 
    293 // TODO: inline?
    294 Correction::CorrectionType Correction::processSkipChar(
    295         const int32_t c, const bool isTerminal, const bool inputIndexIncremented) {
    296     addCharToCurrentWord(c);
    297     if (needsToTraverseAllNodes() && isTerminal) {
    298         mTerminalInputIndex = mInputIndex - (inputIndexIncremented ? 1 : 0);
    299         mTerminalOutputIndex = mOutputIndex;
    300         incrementOutputIndex();
    301         return TRAVERSE_ALL_ON_TERMINAL;
    302     } else {
    303         incrementOutputIndex();
    304         return TRAVERSE_ALL_NOT_ON_TERMINAL;
    305     }
    306 }
    307 
    308 inline bool isEquivalentChar(ProximityInfo::ProximityType type) {
    309     return type == ProximityInfo::EQUIVALENT_CHAR;
    310 }
    311 
    312 Correction::CorrectionType Correction::processCharAndCalcState(
    313         const int32_t c, const bool isTerminal) {
    314     const int correctionCount = (mSkippedCount + mExcessiveCount + mTransposedCount);
    315     // TODO: Change the limit if we'll allow two or more corrections
    316     const bool noCorrectionsHappenedSoFar = correctionCount == 0;
    317     const bool canTryCorrection = noCorrectionsHappenedSoFar;
    318     int proximityIndex = 0;
    319     mDistances[mOutputIndex] = NOT_A_DISTANCE;
    320 
    321     if (mNeedsToTraverseAllNodes || isQuote(c)) {
    322         bool incremented = false;
    323         if (mLastCharExceeded && mInputIndex == mInputLength - 1) {
    324             // TODO: Do not check the proximity if EditDistance exceeds the threshold
    325             const ProximityInfo::ProximityType matchId =
    326                     mProximityInfo->getMatchedProximityId(mInputIndex, c, true, &proximityIndex);
    327             if (isEquivalentChar(matchId)) {
    328                 mLastCharExceeded = false;
    329                 --mExcessiveCount;
    330                 mDistances[mOutputIndex] =
    331                         mProximityInfo->getNormalizedSquaredDistance(mInputIndex, 0);
    332             } else if (matchId == ProximityInfo::NEAR_PROXIMITY_CHAR) {
    333                 mLastCharExceeded = false;
    334                 --mExcessiveCount;
    335                 ++mProximityCount;
    336                 mDistances[mOutputIndex] =
    337                         mProximityInfo->getNormalizedSquaredDistance(mInputIndex, proximityIndex);
    338             }
    339             incrementInputIndex();
    340             incremented = true;
    341         }
    342         return processSkipChar(c, isTerminal, incremented);
    343     }
    344 
    345     if (mExcessivePos >= 0) {
    346         if (mExcessiveCount == 0 && mExcessivePos < mOutputIndex) {
    347             mExcessivePos = mOutputIndex;
    348         }
    349         if (mExcessivePos < mInputLength - 1) {
    350             mExceeding = mExcessivePos == mInputIndex && canTryCorrection;
    351         }
    352     }
    353 
    354     if (mSkipPos >= 0) {
    355         if (mSkippedCount == 0 && mSkipPos < mOutputIndex) {
    356             if (DEBUG_DICT) {
    357                 assert(mSkipPos == mOutputIndex - 1);
    358             }
    359             mSkipPos = mOutputIndex;
    360         }
    361         mSkipping = mSkipPos == mOutputIndex && canTryCorrection;
    362     }
    363 
    364     if (mTransposedPos >= 0) {
    365         if (mTransposedCount == 0 && mTransposedPos < mOutputIndex) {
    366             mTransposedPos = mOutputIndex;
    367         }
    368         if (mTransposedPos < mInputLength - 1) {
    369             mTransposing = mInputIndex == mTransposedPos && canTryCorrection;
    370         }
    371     }
    372 
    373     bool secondTransposing = false;
    374     if (mTransposedCount % 2 == 1) {
    375         if (isEquivalentChar(mProximityInfo->getMatchedProximityId(mInputIndex - 1, c, false))) {
    376             ++mTransposedCount;
    377             secondTransposing = true;
    378         } else if (mCorrectionStates[mOutputIndex].mExceeding) {
    379             --mTransposedCount;
    380             ++mExcessiveCount;
    381             --mExcessivePos;
    382             incrementInputIndex();
    383         } else {
    384             --mTransposedCount;
    385             if (DEBUG_CORRECTION) {
    386                 DUMP_WORD(mWord, mOutputIndex);
    387                 LOGI("UNRELATED(0): %d, %d, %d, %d, %c", mProximityCount, mSkippedCount,
    388                         mTransposedCount, mExcessiveCount, c);
    389             }
    390             return UNRELATED;
    391         }
    392     }
    393 
    394     // TODO: Change the limit if we'll allow two or more proximity chars with corrections
    395     const bool checkProximityChars = noCorrectionsHappenedSoFar ||  mProximityCount == 0;
    396     ProximityInfo::ProximityType matchedProximityCharId = secondTransposing
    397             ? ProximityInfo::EQUIVALENT_CHAR
    398             : mProximityInfo->getMatchedProximityId(
    399                     mInputIndex, c, checkProximityChars, &proximityIndex);
    400 
    401     if (ProximityInfo::UNRELATED_CHAR == matchedProximityCharId) {
    402         if (canTryCorrection && mOutputIndex > 0
    403                 && mCorrectionStates[mOutputIndex].mProximityMatching
    404                 && mCorrectionStates[mOutputIndex].mExceeding
    405                 && isEquivalentChar(mProximityInfo->getMatchedProximityId(
    406                         mInputIndex, mWord[mOutputIndex - 1], false))) {
    407             if (DEBUG_CORRECTION) {
    408                 LOGI("CONVERSION p->e %c", mWord[mOutputIndex - 1]);
    409             }
    410             // Conversion p->e
    411             // Example:
    412             // wearth ->    earth
    413             // px     -> (E)mmmmm
    414             ++mExcessiveCount;
    415             --mProximityCount;
    416             mExcessivePos = mOutputIndex - 1;
    417             ++mInputIndex;
    418             // Here, we are doing something equivalent to matchedProximityCharId,
    419             // but we already know that "excessive char correction" just happened
    420             // so that we just need to check "mProximityCount == 0".
    421             matchedProximityCharId = mProximityInfo->getMatchedProximityId(
    422                     mInputIndex, c, mProximityCount == 0, &proximityIndex);
    423         }
    424     }
    425 
    426     if (ProximityInfo::UNRELATED_CHAR == matchedProximityCharId) {
    427         // TODO: Optimize
    428         // As the current char turned out to be an unrelated char,
    429         // we will try other correction-types. Please note that mCorrectionStates[mOutputIndex]
    430         // here refers to the previous state.
    431         if (mInputIndex < mInputLength - 1 && mOutputIndex > 0 && mTransposedCount > 0
    432                 && !mCorrectionStates[mOutputIndex].mTransposing
    433                 && mCorrectionStates[mOutputIndex - 1].mTransposing
    434                 && isEquivalentChar(mProximityInfo->getMatchedProximityId(
    435                         mInputIndex, mWord[mOutputIndex - 1], false))
    436                 && isEquivalentChar(
    437                         mProximityInfo->getMatchedProximityId(mInputIndex + 1, c, false))) {
    438             // Conversion t->e
    439             // Example:
    440             // occaisional -> occa   sional
    441             // mmmmttx     -> mmmm(E)mmmmmm
    442             mTransposedCount -= 2;
    443             ++mExcessiveCount;
    444             ++mInputIndex;
    445         } else if (mOutputIndex > 0 && mInputIndex > 0 && mTransposedCount > 0
    446                 && !mCorrectionStates[mOutputIndex].mTransposing
    447                 && mCorrectionStates[mOutputIndex - 1].mTransposing
    448                 && isEquivalentChar(
    449                         mProximityInfo->getMatchedProximityId(mInputIndex - 1, c, false))) {
    450             // Conversion t->s
    451             // Example:
    452             // chcolate -> chocolate
    453             // mmttx    -> mmsmmmmmm
    454             mTransposedCount -= 2;
    455             ++mSkippedCount;
    456             --mInputIndex;
    457         } else if (canTryCorrection && mInputIndex > 0
    458                 && mCorrectionStates[mOutputIndex].mProximityMatching
    459                 && mCorrectionStates[mOutputIndex].mSkipping
    460                 && isEquivalentChar(
    461                         mProximityInfo->getMatchedProximityId(mInputIndex - 1, c, false))) {
    462             // Conversion p->s
    463             // Note: This logic tries saving cases like contrst --> contrast -- "a" is one of
    464             // proximity chars of "s", but it should rather be handled as a skipped char.
    465             ++mSkippedCount;
    466             --mProximityCount;
    467             return processSkipChar(c, isTerminal, false);
    468         } else if ((mExceeding || mTransposing) && mInputIndex - 1 < mInputLength
    469                 && isEquivalentChar(
    470                         mProximityInfo->getMatchedProximityId(mInputIndex + 1, c, false))) {
    471             // 1.2. Excessive or transpose correction
    472             if (mTransposing) {
    473                 ++mTransposedCount;
    474             } else {
    475                 ++mExcessiveCount;
    476                 incrementInputIndex();
    477             }
    478         } else if (mSkipping) {
    479             // 3. Skip correction
    480             ++mSkippedCount;
    481             return processSkipChar(c, isTerminal, false);
    482         } else {
    483             if (DEBUG_CORRECTION) {
    484                 DUMP_WORD(mWord, mOutputIndex);
    485                 LOGI("UNRELATED(1): %d, %d, %d, %d, %c", mProximityCount, mSkippedCount,
    486                         mTransposedCount, mExcessiveCount, c);
    487             }
    488             return UNRELATED;
    489         }
    490     } else if (secondTransposing) {
    491         // If inputIndex is greater than mInputLength, that means there is no
    492         // proximity chars. So, we don't need to check proximity.
    493         mMatching = true;
    494     } else if (isEquivalentChar(matchedProximityCharId)) {
    495         mMatching = true;
    496         ++mEquivalentCharCount;
    497         mDistances[mOutputIndex] = mProximityInfo->getNormalizedSquaredDistance(mInputIndex, 0);
    498     } else if (ProximityInfo::NEAR_PROXIMITY_CHAR == matchedProximityCharId) {
    499         mProximityMatching = true;
    500         ++mProximityCount;
    501         mDistances[mOutputIndex] =
    502                 mProximityInfo->getNormalizedSquaredDistance(mInputIndex, proximityIndex);
    503     }
    504 
    505     addCharToCurrentWord(c);
    506 
    507     // 4. Last char excessive correction
    508     mLastCharExceeded = mExcessiveCount == 0 && mSkippedCount == 0 && mTransposedCount == 0
    509             && mProximityCount == 0 && (mInputIndex == mInputLength - 2);
    510     const bool isSameAsUserTypedLength = (mInputLength == mInputIndex + 1) || mLastCharExceeded;
    511     if (mLastCharExceeded) {
    512         ++mExcessiveCount;
    513     }
    514 
    515     // Start traversing all nodes after the index exceeds the user typed length
    516     if (isSameAsUserTypedLength) {
    517         startToTraverseAllNodes();
    518     }
    519 
    520     const bool needsToTryOnTerminalForTheLastPossibleExcessiveChar =
    521             mExceeding && mInputIndex == mInputLength - 2;
    522 
    523     // Finally, we are ready to go to the next character, the next "virtual node".
    524     // We should advance the input index.
    525     // We do this in this branch of the 'if traverseAllNodes' because we are still matching
    526     // characters to input; the other branch is not matching them but searching for
    527     // completions, this is why it does not have to do it.
    528     incrementInputIndex();
    529     // Also, the next char is one "virtual node" depth more than this char.
    530     incrementOutputIndex();
    531 
    532     if ((needsToTryOnTerminalForTheLastPossibleExcessiveChar
    533             || isSameAsUserTypedLength) && isTerminal) {
    534         mTerminalInputIndex = mInputIndex - 1;
    535         mTerminalOutputIndex = mOutputIndex - 1;
    536         if (DEBUG_CORRECTION) {
    537             DUMP_WORD(mWord, mOutputIndex);
    538             LOGI("ONTERMINAL(1): %d, %d, %d, %d, %c", mProximityCount, mSkippedCount,
    539                     mTransposedCount, mExcessiveCount, c);
    540         }
    541         return ON_TERMINAL;
    542     } else {
    543         return NOT_ON_TERMINAL;
    544     }
    545 }
    546 
    547 Correction::~Correction() {
    548 }
    549 
    550 /////////////////////////
    551 // static inline utils //
    552 /////////////////////////
    553 
    554 static const int TWO_31ST_DIV_255 = S_INT_MAX / 255;
    555 static inline int capped255MultForFullMatchAccentsOrCapitalizationDifference(const int num) {
    556     return (num < TWO_31ST_DIV_255 ? 255 * num : S_INT_MAX);
    557 }
    558 
    559 static const int TWO_31ST_DIV_2 = S_INT_MAX / 2;
    560 inline static void multiplyIntCapped(const int multiplier, int *base) {
    561     const int temp = *base;
    562     if (temp != S_INT_MAX) {
    563         // Branch if multiplier == 2 for the optimization
    564         if (multiplier == 2) {
    565             *base = TWO_31ST_DIV_2 >= temp ? temp << 1 : S_INT_MAX;
    566         } else {
    567             // TODO: This overflow check gives a wrong answer when, for example,
    568             //       temp = 2^16 + 1 and multiplier = 2^17 + 1.
    569             //       Fix this behavior.
    570             const int tempRetval = temp * multiplier;
    571             *base = tempRetval >= temp ? tempRetval : S_INT_MAX;
    572         }
    573     }
    574 }
    575 
    576 inline static int powerIntCapped(const int base, const int n) {
    577     if (n <= 0) return 1;
    578     if (base == 2) {
    579         return n < 31 ? 1 << n : S_INT_MAX;
    580     } else {
    581         int ret = base;
    582         for (int i = 1; i < n; ++i) multiplyIntCapped(base, &ret);
    583         return ret;
    584     }
    585 }
    586 
    587 inline static void multiplyRate(const int rate, int *freq) {
    588     if (*freq != S_INT_MAX) {
    589         if (*freq > 1000000) {
    590             *freq /= 100;
    591             multiplyIntCapped(rate, freq);
    592         } else {
    593             multiplyIntCapped(rate, freq);
    594             *freq /= 100;
    595         }
    596     }
    597 }
    598 
    599 inline static int getQuoteCount(const unsigned short* word, const int length) {
    600     int quoteCount = 0;
    601     for (int i = 0; i < length; ++i) {
    602         if(word[i] == '\'') {
    603             ++quoteCount;
    604         }
    605     }
    606     return quoteCount;
    607 }
    608 
    609 inline static bool isUpperCase(unsigned short c) {
    610      if (c < sizeof(BASE_CHARS) / sizeof(BASE_CHARS[0])) {
    611          c = BASE_CHARS[c];
    612      }
    613      if (isupper(c)) {
    614          return true;
    615      }
    616      return false;
    617 }
    618 
    619 //////////////////////
    620 // RankingAlgorithm //
    621 //////////////////////
    622 
    623 /* static */
    624 int Correction::RankingAlgorithm::calculateFinalFreq(const int inputIndex, const int outputIndex,
    625         const int freq, int* editDistanceTable, const Correction* correction) {
    626     const int excessivePos = correction->getExcessivePos();
    627     const int inputLength = correction->mInputLength;
    628     const int typedLetterMultiplier = correction->TYPED_LETTER_MULTIPLIER;
    629     const int fullWordMultiplier = correction->FULL_WORD_MULTIPLIER;
    630     const ProximityInfo *proximityInfo = correction->mProximityInfo;
    631     const int skippedCount = correction->mSkippedCount;
    632     const int transposedCount = correction->mTransposedCount / 2;
    633     const int excessiveCount = correction->mExcessiveCount + correction->mTransposedCount % 2;
    634     const int proximityMatchedCount = correction->mProximityCount;
    635     const bool lastCharExceeded = correction->mLastCharExceeded;
    636     const bool useFullEditDistance = correction->mUseFullEditDistance;
    637     const int outputLength = outputIndex + 1;
    638     if (skippedCount >= inputLength || inputLength == 0) {
    639         return -1;
    640     }
    641 
    642     // TODO: find more robust way
    643     bool sameLength = lastCharExceeded ? (inputLength == inputIndex + 2)
    644             : (inputLength == inputIndex + 1);
    645 
    646     // TODO: use mExcessiveCount
    647     const int matchCount = inputLength - correction->mProximityCount - excessiveCount;
    648 
    649     const unsigned short* word = correction->mWord;
    650     const bool skipped = skippedCount > 0;
    651 
    652     const int quoteDiffCount = max(0, getQuoteCount(word, outputIndex + 1)
    653             - getQuoteCount(proximityInfo->getPrimaryInputWord(), inputLength));
    654 
    655     // TODO: Calculate edit distance for transposed and excessive
    656     int ed = 0;
    657     int adjustedProximityMatchedCount = proximityMatchedCount;
    658 
    659     int finalFreq = freq;
    660 
    661     // TODO: Optimize this.
    662     // TODO: Ignoring edit distance for transposed char, for now
    663     if (transposedCount == 0 && (proximityMatchedCount > 0 || skipped || excessiveCount > 0)) {
    664         ed = getCurrentEditDistance(editDistanceTable, inputLength, outputIndex + 1);
    665         const int matchWeight = powerIntCapped(typedLetterMultiplier,
    666                 max(inputLength, outputIndex + 1) - ed);
    667         multiplyIntCapped(matchWeight, &finalFreq);
    668 
    669         // TODO: Demote further if there are two or more excessive chars with longer user input?
    670         if (inputLength > outputIndex + 1) {
    671             multiplyRate(INPUT_EXCEEDS_OUTPUT_DEMOTION_RATE, &finalFreq);
    672         }
    673 
    674         ed = max(0, ed - quoteDiffCount);
    675 
    676         if (ed == 1 && (inputLength == outputIndex || inputLength == outputIndex + 2)) {
    677             // Promote a word with just one skipped or excessive char
    678             if (sameLength) {
    679                 multiplyRate(WORDS_WITH_JUST_ONE_CORRECTION_PROMOTION_RATE, &finalFreq);
    680             } else {
    681                 multiplyIntCapped(typedLetterMultiplier, &finalFreq);
    682             }
    683         } else if (ed == 0) {
    684             multiplyIntCapped(typedLetterMultiplier, &finalFreq);
    685             sameLength = true;
    686         }
    687         adjustedProximityMatchedCount = min(max(0, ed - (outputIndex + 1 - inputLength)),
    688                 proximityMatchedCount);
    689     } else {
    690         // TODO: Calculate the edit distance for transposed char
    691         const int matchWeight = powerIntCapped(typedLetterMultiplier, matchCount);
    692         multiplyIntCapped(matchWeight, &finalFreq);
    693     }
    694 
    695     if (proximityInfo->getMatchedProximityId(0, word[0], true)
    696             == ProximityInfo::UNRELATED_CHAR) {
    697         multiplyRate(FIRST_CHAR_DIFFERENT_DEMOTION_RATE, &finalFreq);
    698     }
    699 
    700     ///////////////////////////////////////////////
    701     // Promotion and Demotion for each correction
    702 
    703     // Demotion for a word with missing character
    704     if (skipped) {
    705         const int demotionRate = WORDS_WITH_MISSING_CHARACTER_DEMOTION_RATE
    706                 * (10 * inputLength - WORDS_WITH_MISSING_CHARACTER_DEMOTION_START_POS_10X)
    707                 / (10 * inputLength
    708                         - WORDS_WITH_MISSING_CHARACTER_DEMOTION_START_POS_10X + 10);
    709         if (DEBUG_DICT_FULL) {
    710             LOGI("Demotion rate for missing character is %d.", demotionRate);
    711         }
    712         multiplyRate(demotionRate, &finalFreq);
    713     }
    714 
    715     // Demotion for a word with transposed character
    716     if (transposedCount > 0) multiplyRate(
    717             WORDS_WITH_TRANSPOSED_CHARACTERS_DEMOTION_RATE, &finalFreq);
    718 
    719     // Demotion for a word with excessive character
    720     if (excessiveCount > 0) {
    721         multiplyRate(WORDS_WITH_EXCESSIVE_CHARACTER_DEMOTION_RATE, &finalFreq);
    722         if (!lastCharExceeded && !proximityInfo->existsAdjacentProximityChars(excessivePos)) {
    723             if (DEBUG_CORRECTION_FREQ) {
    724                 LOGI("Double excessive demotion");
    725             }
    726             // If an excessive character is not adjacent to the left char or the right char,
    727             // we will demote this word.
    728             multiplyRate(WORDS_WITH_EXCESSIVE_CHARACTER_OUT_OF_PROXIMITY_DEMOTION_RATE, &finalFreq);
    729         }
    730     }
    731 
    732     // Score calibration by touch coordinates is being done only for pure-fat finger typing error
    733     // cases.
    734     // TODO: Remove this constraint.
    735     if (CALIBRATE_SCORE_BY_TOUCH_COORDINATES && proximityInfo->touchPositionCorrectionEnabled()
    736             && skippedCount == 0 && excessiveCount == 0 && transposedCount == 0) {
    737         for (int i = 0; i < outputLength; ++i) {
    738             const int squaredDistance = correction->mDistances[i];
    739             if (i < adjustedProximityMatchedCount) {
    740                 multiplyIntCapped(typedLetterMultiplier, &finalFreq);
    741             }
    742             if (squaredDistance >= 0) {
    743                 // Promote or demote the score according to the distance from the sweet spot
    744                 static const float A = ZERO_DISTANCE_PROMOTION_RATE / 100.0f;
    745                 static const float B = 1.0f;
    746                 static const float C = 0.5f;
    747                 static const float R1 = NEUTRAL_SCORE_SQUARED_RADIUS;
    748                 static const float R2 = HALF_SCORE_SQUARED_RADIUS;
    749                 const float x = (float)squaredDistance
    750                         / ProximityInfo::NORMALIZED_SQUARED_DISTANCE_SCALING_FACTOR;
    751                 const float factor = (x < R1)
    752                     ? (A * (R1 - x) + B * x) / R1
    753                     : (B * (R2 - x) + C * (x - R1)) / (R2 - R1);
    754                 // factor is piecewise linear function like:
    755                 // A -_                  .
    756                 //     ^-_               .
    757                 // B      \              .
    758                 //         \             .
    759                 // C        \            .
    760                 //   0   R1 R2
    761                 if (factor <= 0) {
    762                     return -1;
    763                 }
    764                 multiplyRate((int)(factor * 100), &finalFreq);
    765             } else if (squaredDistance == PROXIMITY_CHAR_WITHOUT_DISTANCE_INFO) {
    766                 multiplyRate(WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE, &finalFreq);
    767             }
    768         }
    769     } else {
    770         // Promotion for a word with proximity characters
    771         for (int i = 0; i < adjustedProximityMatchedCount; ++i) {
    772             // A word with proximity corrections
    773             if (DEBUG_DICT_FULL) {
    774                 LOGI("Found a proximity correction.");
    775             }
    776             multiplyIntCapped(typedLetterMultiplier, &finalFreq);
    777             multiplyRate(WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE, &finalFreq);
    778         }
    779     }
    780 
    781     const int errorCount = adjustedProximityMatchedCount > 0
    782             ? adjustedProximityMatchedCount
    783             : (proximityMatchedCount + transposedCount);
    784     multiplyRate(
    785             100 - CORRECTION_COUNT_RATE_DEMOTION_RATE_BASE * errorCount / inputLength, &finalFreq);
    786 
    787     // Promotion for an exactly matched word
    788     if (ed == 0) {
    789         // Full exact match
    790         if (sameLength && transposedCount == 0 && !skipped && excessiveCount == 0) {
    791             finalFreq = capped255MultForFullMatchAccentsOrCapitalizationDifference(finalFreq);
    792         }
    793     }
    794 
    795     // Promote a word with no correction
    796     if (proximityMatchedCount == 0 && transposedCount == 0 && !skipped && excessiveCount == 0) {
    797         multiplyRate(FULL_MATCHED_WORDS_PROMOTION_RATE, &finalFreq);
    798     }
    799 
    800     // TODO: Check excessive count and transposed count
    801     // TODO: Remove this if possible
    802     /*
    803          If the last character of the user input word is the same as the next character
    804          of the output word, and also all of characters of the user input are matched
    805          to the output word, we'll promote that word a bit because
    806          that word can be considered the combination of skipped and matched characters.
    807          This means that the 'sm' pattern wins over the 'ma' pattern.
    808          e.g.)
    809          shel -> shell [mmmma] or [mmmsm]
    810          hel -> hello [mmmaa] or [mmsma]
    811          m ... matching
    812          s ... skipping
    813          a ... traversing all
    814          t ... transposing
    815          e ... exceeding
    816          p ... proximity matching
    817      */
    818     if (matchCount == inputLength && matchCount >= 2 && !skipped
    819             && word[matchCount] == word[matchCount - 1]) {
    820         multiplyRate(WORDS_WITH_MATCH_SKIP_PROMOTION_RATE, &finalFreq);
    821     }
    822 
    823     // TODO: Do not use sameLength?
    824     if (sameLength) {
    825         multiplyIntCapped(fullWordMultiplier, &finalFreq);
    826     }
    827 
    828     if (useFullEditDistance && outputLength > inputLength + 1) {
    829         const int diff = outputLength - inputLength - 1;
    830         const int divider = diff < 31 ? 1 << diff : S_INT_MAX;
    831         finalFreq = divider > finalFreq ? 1 : finalFreq / divider;
    832     }
    833 
    834     if (DEBUG_DICT_FULL) {
    835         LOGI("calc: %d, %d", outputIndex, sameLength);
    836     }
    837 
    838     if (DEBUG_CORRECTION_FREQ) {
    839         DUMP_WORD(correction->mWord, outputIndex + 1);
    840         LOGI("FinalFreq: [P%d, S%d, T%d, E%d] %d, %d, %d, %d, %d", proximityMatchedCount,
    841                 skippedCount, transposedCount, excessiveCount, lastCharExceeded, sameLength,
    842                 quoteDiffCount, ed, finalFreq);
    843     }
    844 
    845     return finalFreq;
    846 }
    847 
    848 /* static */
    849 int Correction::RankingAlgorithm::calcFreqForSplitTwoWords(
    850         const int firstFreq, const int secondFreq, const Correction* correction,
    851         const unsigned short *word) {
    852     const int spaceProximityPos = correction->mSpaceProximityPos;
    853     const int missingSpacePos = correction->mMissingSpacePos;
    854     if (DEBUG_DICT) {
    855         int inputCount = 0;
    856         if (spaceProximityPos >= 0) ++inputCount;
    857         if (missingSpacePos >= 0) ++inputCount;
    858         assert(inputCount <= 1);
    859     }
    860     const bool isSpaceProximity = spaceProximityPos >= 0;
    861     const int inputLength = correction->mInputLength;
    862     const int firstWordLength = isSpaceProximity ? spaceProximityPos : missingSpacePos;
    863     const int secondWordLength = isSpaceProximity ? (inputLength - spaceProximityPos - 1)
    864             : (inputLength - missingSpacePos);
    865     const int typedLetterMultiplier = correction->TYPED_LETTER_MULTIPLIER;
    866 
    867     bool firstCapitalizedWordDemotion = false;
    868     if (firstWordLength >= 2) {
    869         firstCapitalizedWordDemotion = isUpperCase(word[0]);
    870     }
    871 
    872     bool secondCapitalizedWordDemotion = false;
    873     if (secondWordLength >= 2) {
    874         secondCapitalizedWordDemotion = isUpperCase(word[firstWordLength + 1]);
    875     }
    876 
    877     const bool capitalizedWordDemotion =
    878             firstCapitalizedWordDemotion ^ secondCapitalizedWordDemotion;
    879 
    880     if (DEBUG_DICT_FULL) {
    881         LOGI("Two words: %c, %c, %d", word[0], word[firstWordLength + 1], capitalizedWordDemotion);
    882     }
    883 
    884     if (firstWordLength == 0 || secondWordLength == 0) {
    885         return 0;
    886     }
    887     const int firstDemotionRate = 100 - 100 / (firstWordLength + 1);
    888     int tempFirstFreq = firstFreq;
    889     multiplyRate(firstDemotionRate, &tempFirstFreq);
    890 
    891     const int secondDemotionRate = 100 - 100 / (secondWordLength + 1);
    892     int tempSecondFreq = secondFreq;
    893     multiplyRate(secondDemotionRate, &tempSecondFreq);
    894 
    895     const int totalLength = firstWordLength + secondWordLength;
    896 
    897     // Promote pairFreq with multiplying by 2, because the word length is the same as the typed
    898     // length.
    899     int totalFreq = tempFirstFreq + tempSecondFreq;
    900 
    901     // This is a workaround to try offsetting the not-enough-demotion which will be done in
    902     // calcNormalizedScore in Utils.java.
    903     // In calcNormalizedScore the score will be demoted by (1 - 1 / length)
    904     // but we demoted only (1 - 1 / (length + 1)) so we will additionally adjust freq by
    905     // (1 - 1 / length) / (1 - 1 / (length + 1)) = (1 - 1 / (length * length))
    906     const int normalizedScoreNotEnoughDemotionAdjustment = 100 - 100 / (totalLength * totalLength);
    907     multiplyRate(normalizedScoreNotEnoughDemotionAdjustment, &totalFreq);
    908 
    909     // At this moment, totalFreq is calculated by the following formula:
    910     // (firstFreq * (1 - 1 / (firstWordLength + 1)) + secondFreq * (1 - 1 / (secondWordLength + 1)))
    911     //        * (1 - 1 / totalLength) / (1 - 1 / (totalLength + 1))
    912 
    913     multiplyIntCapped(powerIntCapped(typedLetterMultiplier, totalLength), &totalFreq);
    914 
    915     // This is another workaround to offset the demotion which will be done in
    916     // calcNormalizedScore in Utils.java.
    917     // In calcNormalizedScore the score will be demoted by (1 - 1 / length) so we have to promote
    918     // the same amount because we already have adjusted the synthetic freq of this "missing or
    919     // mistyped space" suggestion candidate above in this method.
    920     const int normalizedScoreDemotionRateOffset = (100 + 100 / totalLength);
    921     multiplyRate(normalizedScoreDemotionRateOffset, &totalFreq);
    922 
    923     if (isSpaceProximity) {
    924         // A word pair with one space proximity correction
    925         if (DEBUG_DICT) {
    926             LOGI("Found a word pair with space proximity correction.");
    927         }
    928         multiplyIntCapped(typedLetterMultiplier, &totalFreq);
    929         multiplyRate(WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE, &totalFreq);
    930     }
    931 
    932     multiplyRate(WORDS_WITH_MISSING_SPACE_CHARACTER_DEMOTION_RATE, &totalFreq);
    933 
    934     if (capitalizedWordDemotion) {
    935         multiplyRate(TWO_WORDS_CAPITALIZED_DEMOTION_RATE, &totalFreq);
    936     }
    937 
    938     return totalFreq;
    939 }
    940 
    941 } // namespace latinime
    942