Home | History | Annotate | Download | only in src
      1 /*
      2  * Copyright (C) 2011 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef LATINIME_CORRECTION_H
     18 #define LATINIME_CORRECTION_H
     19 
     20 #include <cstring> // for memset()
     21 
     22 #include "correction_state.h"
     23 #include "defines.h"
     24 #include "proximity_info_state.h"
     25 
     26 namespace latinime {
     27 
     28 class ProximityInfo;
     29 
     30 class Correction {
     31  public:
     32     typedef enum {
     33         TRAVERSE_ALL_ON_TERMINAL,
     34         TRAVERSE_ALL_NOT_ON_TERMINAL,
     35         UNRELATED,
     36         ON_TERMINAL,
     37         NOT_ON_TERMINAL
     38     } CorrectionType;
     39 
     40     Correction()
     41             : mProximityInfo(0), mUseFullEditDistance(false), mDoAutoCompletion(false),
     42               mMaxEditDistance(0), mMaxDepth(0), mInputSize(0), mSpaceProximityPos(0),
     43               mMissingSpacePos(0), mTerminalInputIndex(0), mTerminalOutputIndex(0), mMaxErrors(0),
     44               mTotalTraverseCount(0), mNeedsToTraverseAllNodes(false), mOutputIndex(0),
     45               mInputIndex(0), mEquivalentCharCount(0), mProximityCount(0), mExcessiveCount(0),
     46               mTransposedCount(0), mSkippedCount(0), mTransposedPos(0), mExcessivePos(0),
     47               mSkipPos(0), mLastCharExceeded(false), mMatching(false), mProximityMatching(false),
     48               mAdditionalProximityMatching(false), mExceeding(false), mTransposing(false),
     49               mSkipping(false), mProximityInfoState() {
     50         memset(mWord, 0, sizeof(mWord));
     51         memset(mDistances, 0, sizeof(mDistances));
     52         memset(mEditDistanceTable, 0, sizeof(mEditDistanceTable));
     53         // NOTE: mCorrectionStates is an array of instances.
     54         // No need to initialize it explicitly here.
     55     }
     56 
     57     // Non virtual inline destructor -- never inherit this class
     58     ~Correction() {}
     59     void resetCorrection();
     60     void initCorrection(const ProximityInfo *pi, const int inputSize, const int maxDepth);
     61     void initCorrectionState(const int rootPos, const int childCount, const bool traverseAll);
     62 
     63     // TODO: remove
     64     void setCorrectionParams(const int skipPos, const int excessivePos, const int transposedPos,
     65             const int spaceProximityPos, const int missingSpacePos, const bool useFullEditDistance,
     66             const bool doAutoCompletion, const int maxErrors);
     67     void checkState() const;
     68     bool sameAsTyped() const;
     69     bool initProcessState(const int index);
     70 
     71     int getInputIndex() const;
     72 
     73     bool needsToPrune() const;
     74 
     75     int pushAndGetTotalTraverseCount() {
     76         return ++mTotalTraverseCount;
     77     }
     78 
     79     int getFreqForSplitMultipleWords(const int *freqArray, const int *wordLengthArray,
     80             const int wordCount, const bool isSpaceProximity, const int *word) const;
     81     int getFinalProbability(const int probability, int **word, int *wordLength);
     82     int getFinalProbabilityForSubQueue(const int probability, int **word, int *wordLength,
     83             const int inputSize);
     84 
     85     CorrectionType processCharAndCalcState(const int c, const bool isTerminal);
     86 
     87     /////////////////////////
     88     // Tree helper methods
     89     int goDownTree(const int parentIndex, const int childCount, const int firstChildPos);
     90 
     91     inline int getTreeSiblingPos(const int index) const {
     92         return mCorrectionStates[index].mSiblingPos;
     93     }
     94 
     95     inline void setTreeSiblingPos(const int index, const int pos) {
     96         mCorrectionStates[index].mSiblingPos = pos;
     97     }
     98 
     99     inline int getTreeParentIndex(const int index) const {
    100         return mCorrectionStates[index].mParentIndex;
    101     }
    102 
    103     class RankingAlgorithm {
    104      public:
    105         static int calculateFinalProbability(const int inputIndex, const int depth,
    106                 const int probability, int *editDistanceTable, const Correction *correction,
    107                 const int inputSize);
    108         static int calcFreqForSplitMultipleWords(const int *freqArray, const int *wordLengthArray,
    109                 const int wordCount, const Correction *correction, const bool isSpaceProximity,
    110                 const int *word);
    111         static float calcNormalizedScore(const int *before, const int beforeLength,
    112                 const int *after, const int afterLength, const int score);
    113         static int editDistance(const int *before, const int beforeLength, const int *after,
    114                 const int afterLength);
    115      private:
    116         static const int MAX_INITIAL_SCORE = 255;
    117     };
    118 
    119     // proximity info state
    120     void initInputParams(const ProximityInfo *proximityInfo, const int *inputCodes,
    121             const int inputSize, const int *xCoordinates, const int *yCoordinates) {
    122         mProximityInfoState.initInputParams(0, static_cast<float>(MAX_VALUE_FOR_WEIGHTING),
    123                 proximityInfo, inputCodes, inputSize, xCoordinates, yCoordinates, 0, 0, false);
    124     }
    125 
    126     const int *getPrimaryInputWord() const {
    127         return mProximityInfoState.getPrimaryInputWord();
    128     }
    129 
    130     int getPrimaryCodePointAt(const int index) const {
    131         return mProximityInfoState.getPrimaryCodePointAt(index);
    132     }
    133 
    134  private:
    135     DISALLOW_COPY_AND_ASSIGN(Correction);
    136 
    137     /////////////////////////
    138     // static inline utils //
    139     /////////////////////////
    140     static const int TWO_31ST_DIV_255 = S_INT_MAX / 255;
    141     static inline int capped255MultForFullMatchAccentsOrCapitalizationDifference(const int num) {
    142         return (num < TWO_31ST_DIV_255 ? 255 * num : S_INT_MAX);
    143     }
    144 
    145     static const int TWO_31ST_DIV_2 = S_INT_MAX / 2;
    146     AK_FORCE_INLINE static void multiplyIntCapped(const int multiplier, int *base) {
    147         const int temp = *base;
    148         if (temp != S_INT_MAX) {
    149             // Branch if multiplier == 2 for the optimization
    150             if (multiplier < 0) {
    151                 if (DEBUG_DICT) {
    152                     ASSERT(false);
    153                 }
    154                 AKLOGI("--- Invalid multiplier: %d", multiplier);
    155             } else if (multiplier == 0) {
    156                 *base = 0;
    157             } else if (multiplier == 2) {
    158                 *base = TWO_31ST_DIV_2 >= temp ? temp << 1 : S_INT_MAX;
    159             } else {
    160                 // TODO: This overflow check gives a wrong answer when, for example,
    161                 //       temp = 2^16 + 1 and multiplier = 2^17 + 1.
    162                 //       Fix this behavior.
    163                 const int tempRetval = temp * multiplier;
    164                 *base = tempRetval >= temp ? tempRetval : S_INT_MAX;
    165             }
    166         }
    167     }
    168 
    169     AK_FORCE_INLINE static int powerIntCapped(const int base, const int n) {
    170         if (n <= 0) return 1;
    171         if (base == 2) {
    172             return n < 31 ? 1 << n : S_INT_MAX;
    173         }
    174         int ret = base;
    175         for (int i = 1; i < n; ++i) multiplyIntCapped(base, &ret);
    176         return ret;
    177     }
    178 
    179     AK_FORCE_INLINE static void multiplyRate(const int rate, int *freq) {
    180         if (*freq != S_INT_MAX) {
    181             if (*freq > 1000000) {
    182                 *freq /= 100;
    183                 multiplyIntCapped(rate, freq);
    184             } else {
    185                 multiplyIntCapped(rate, freq);
    186                 *freq /= 100;
    187             }
    188         }
    189     }
    190 
    191     inline int getSpaceProximityPos() const {
    192         return mSpaceProximityPos;
    193     }
    194     inline int getMissingSpacePos() const {
    195         return mMissingSpacePos;
    196     }
    197 
    198     inline int getSkipPos() const {
    199         return mSkipPos;
    200     }
    201 
    202     inline int getExcessivePos() const {
    203         return mExcessivePos;
    204     }
    205 
    206     inline int getTransposedPos() const {
    207         return mTransposedPos;
    208     }
    209 
    210     inline void incrementInputIndex();
    211     inline void incrementOutputIndex();
    212     inline void startToTraverseAllNodes();
    213     inline bool isSingleQuote(const int c);
    214     inline CorrectionType processSkipChar(const int c, const bool isTerminal,
    215             const bool inputIndexIncremented);
    216     inline CorrectionType processUnrelatedCorrectionType();
    217     inline void addCharToCurrentWord(const int c);
    218     inline int getFinalProbabilityInternal(const int probability, int **word, int *wordLength,
    219             const int inputSize);
    220 
    221     static const int TYPED_LETTER_MULTIPLIER = 2;
    222     static const int FULL_WORD_MULTIPLIER = 2;
    223     const ProximityInfo *mProximityInfo;
    224 
    225     bool mUseFullEditDistance;
    226     bool mDoAutoCompletion;
    227     int mMaxEditDistance;
    228     int mMaxDepth;
    229     int mInputSize;
    230     int mSpaceProximityPos;
    231     int mMissingSpacePos;
    232     int mTerminalInputIndex;
    233     int mTerminalOutputIndex;
    234     int mMaxErrors;
    235 
    236     int mTotalTraverseCount;
    237 
    238     // The following arrays are state buffer.
    239     int mWord[MAX_WORD_LENGTH];
    240     int mDistances[MAX_WORD_LENGTH];
    241 
    242     // Edit distance calculation requires a buffer with (N+1)^2 length for the input length N.
    243     // Caveat: Do not create multiple tables per thread as this table eats up RAM a lot.
    244     int mEditDistanceTable[(MAX_WORD_LENGTH + 1) * (MAX_WORD_LENGTH + 1)];
    245 
    246     CorrectionState mCorrectionStates[MAX_WORD_LENGTH];
    247 
    248     // The following member variables are being used as cache values of the correction state.
    249     bool mNeedsToTraverseAllNodes;
    250     int mOutputIndex;
    251     int mInputIndex;
    252 
    253     int mEquivalentCharCount;
    254     int mProximityCount;
    255     int mExcessiveCount;
    256     int mTransposedCount;
    257     int mSkippedCount;
    258 
    259     int mTransposedPos;
    260     int mExcessivePos;
    261     int mSkipPos;
    262 
    263     bool mLastCharExceeded;
    264 
    265     bool mMatching;
    266     bool mProximityMatching;
    267     bool mAdditionalProximityMatching;
    268     bool mExceeding;
    269     bool mTransposing;
    270     bool mSkipping;
    271     ProximityInfoState mProximityInfoState;
    272 };
    273 
    274 inline void Correction::incrementInputIndex() {
    275     ++mInputIndex;
    276 }
    277 
    278 AK_FORCE_INLINE void Correction::incrementOutputIndex() {
    279     ++mOutputIndex;
    280     mCorrectionStates[mOutputIndex].mParentIndex = mCorrectionStates[mOutputIndex - 1].mParentIndex;
    281     mCorrectionStates[mOutputIndex].mChildCount = mCorrectionStates[mOutputIndex - 1].mChildCount;
    282     mCorrectionStates[mOutputIndex].mSiblingPos = mCorrectionStates[mOutputIndex - 1].mSiblingPos;
    283     mCorrectionStates[mOutputIndex].mInputIndex = mInputIndex;
    284     mCorrectionStates[mOutputIndex].mNeedsToTraverseAllNodes = mNeedsToTraverseAllNodes;
    285 
    286     mCorrectionStates[mOutputIndex].mEquivalentCharCount = mEquivalentCharCount;
    287     mCorrectionStates[mOutputIndex].mProximityCount = mProximityCount;
    288     mCorrectionStates[mOutputIndex].mTransposedCount = mTransposedCount;
    289     mCorrectionStates[mOutputIndex].mExcessiveCount = mExcessiveCount;
    290     mCorrectionStates[mOutputIndex].mSkippedCount = mSkippedCount;
    291 
    292     mCorrectionStates[mOutputIndex].mSkipPos = mSkipPos;
    293     mCorrectionStates[mOutputIndex].mTransposedPos = mTransposedPos;
    294     mCorrectionStates[mOutputIndex].mExcessivePos = mExcessivePos;
    295 
    296     mCorrectionStates[mOutputIndex].mLastCharExceeded = mLastCharExceeded;
    297 
    298     mCorrectionStates[mOutputIndex].mMatching = mMatching;
    299     mCorrectionStates[mOutputIndex].mProximityMatching = mProximityMatching;
    300     mCorrectionStates[mOutputIndex].mAdditionalProximityMatching = mAdditionalProximityMatching;
    301     mCorrectionStates[mOutputIndex].mTransposing = mTransposing;
    302     mCorrectionStates[mOutputIndex].mExceeding = mExceeding;
    303     mCorrectionStates[mOutputIndex].mSkipping = mSkipping;
    304 }
    305 
    306 inline void Correction::startToTraverseAllNodes() {
    307     mNeedsToTraverseAllNodes = true;
    308 }
    309 
    310 AK_FORCE_INLINE bool Correction::isSingleQuote(const int c) {
    311     const int userTypedChar = mProximityInfoState.getPrimaryCodePointAt(mInputIndex);
    312     return (c == KEYCODE_SINGLE_QUOTE && userTypedChar != KEYCODE_SINGLE_QUOTE);
    313 }
    314 
    315 AK_FORCE_INLINE Correction::CorrectionType Correction::processSkipChar(const int c,
    316         const bool isTerminal, const bool inputIndexIncremented) {
    317     addCharToCurrentWord(c);
    318     mTerminalInputIndex = mInputIndex - (inputIndexIncremented ? 1 : 0);
    319     mTerminalOutputIndex = mOutputIndex;
    320     incrementOutputIndex();
    321     if (mNeedsToTraverseAllNodes && isTerminal) {
    322         return TRAVERSE_ALL_ON_TERMINAL;
    323     }
    324     return TRAVERSE_ALL_NOT_ON_TERMINAL;
    325 }
    326 
    327 inline Correction::CorrectionType Correction::processUnrelatedCorrectionType() {
    328     // Needs to set mTerminalInputIndex and mTerminalOutputIndex before returning any CorrectionType
    329     mTerminalInputIndex = mInputIndex;
    330     mTerminalOutputIndex = mOutputIndex;
    331     return UNRELATED;
    332 }
    333 
    334 AK_FORCE_INLINE static void calcEditDistanceOneStep(int *editDistanceTable, const int *input,
    335         const int inputSize, const int *output, const int outputLength) {
    336     // TODO: Make sure that editDistance[0 ~ MAX_WORD_LENGTH] is not touched.
    337     // Let dp[i][j] be editDistanceTable[i * (inputSize + 1) + j].
    338     // Assuming that dp[0][0] ... dp[outputLength - 1][inputSize] are already calculated,
    339     // and calculate dp[ouputLength][0] ... dp[outputLength][inputSize].
    340     int *const current = editDistanceTable + outputLength * (inputSize + 1);
    341     const int *const prev = editDistanceTable + (outputLength - 1) * (inputSize + 1);
    342     const int *const prevprev =
    343             outputLength >= 2 ? editDistanceTable + (outputLength - 2) * (inputSize + 1) : 0;
    344     current[0] = outputLength;
    345     const int co = toBaseLowerCase(output[outputLength - 1]);
    346     const int prevCO = outputLength >= 2 ? toBaseLowerCase(output[outputLength - 2]) : 0;
    347     for (int i = 1; i <= inputSize; ++i) {
    348         const int ci = toBaseLowerCase(input[i - 1]);
    349         const int cost = (ci == co) ? 0 : 1;
    350         current[i] = min(current[i - 1] + 1, min(prev[i] + 1, prev[i - 1] + cost));
    351         if (i >= 2 && prevprev && ci == prevCO && co == toBaseLowerCase(input[i - 2])) {
    352             current[i] = min(current[i], prevprev[i - 2] + 1);
    353         }
    354     }
    355 }
    356 
    357 AK_FORCE_INLINE void Correction::addCharToCurrentWord(const int c) {
    358     mWord[mOutputIndex] = c;
    359     const int *primaryInputWord = mProximityInfoState.getPrimaryInputWord();
    360     calcEditDistanceOneStep(mEditDistanceTable, primaryInputWord, mInputSize, mWord,
    361             mOutputIndex + 1);
    362 }
    363 
    364 inline int Correction::getFinalProbabilityInternal(const int probability, int **word,
    365         int *wordLength, const int inputSize) {
    366     const int outputIndex = mTerminalOutputIndex;
    367     const int inputIndex = mTerminalInputIndex;
    368     *wordLength = outputIndex + 1;
    369     *word = mWord;
    370     int finalProbability= Correction::RankingAlgorithm::calculateFinalProbability(
    371             inputIndex, outputIndex, probability, mEditDistanceTable, this, inputSize);
    372     return finalProbability;
    373 }
    374 
    375 } // namespace latinime
    376 #endif // LATINIME_CORRECTION_H
    377