Home | History | Annotate | Download | only in src
      1 /*
      2  * Copyright (C) 2011 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef LATINIME_CORRECTION_H
     18 #define LATINIME_CORRECTION_H
     19 
     20 #include <cassert>
     21 #include <cstring> // for memset()
     22 #include <stdint.h>
     23 
     24 #include "correction_state.h"
     25 #include "defines.h"
     26 #include "proximity_info_state.h"
     27 
     28 namespace latinime {
     29 
     30 class ProximityInfo;
     31 
     32 class Correction {
     33  public:
     34     typedef enum {
     35         TRAVERSE_ALL_ON_TERMINAL,
     36         TRAVERSE_ALL_NOT_ON_TERMINAL,
     37         UNRELATED,
     38         ON_TERMINAL,
     39         NOT_ON_TERMINAL
     40     } CorrectionType;
     41 
     42     Correction()
     43             : mProximityInfo(0), mUseFullEditDistance(false), mDoAutoCompletion(false),
     44               mMaxEditDistance(0), mMaxDepth(0), mInputSize(0), mSpaceProximityPos(0),
     45               mMissingSpacePos(0), mTerminalInputIndex(0), mTerminalOutputIndex(0), mMaxErrors(0),
     46               mTotalTraverseCount(0), mNeedsToTraverseAllNodes(false), mOutputIndex(0),
     47               mInputIndex(0), mEquivalentCharCount(0), mProximityCount(0), mExcessiveCount(0),
     48               mTransposedCount(0), mSkippedCount(0), mTransposedPos(0), mExcessivePos(0),
     49               mSkipPos(0), mLastCharExceeded(false), mMatching(false), mProximityMatching(false),
     50               mAdditionalProximityMatching(false), mExceeding(false), mTransposing(false),
     51               mSkipping(false), mProximityInfoState() {
     52         memset(mWord, 0, sizeof(mWord));
     53         memset(mDistances, 0, sizeof(mDistances));
     54         memset(mEditDistanceTable, 0, sizeof(mEditDistanceTable));
     55         // NOTE: mCorrectionStates is an array of instances.
     56         // No need to initialize it explicitly here.
     57     }
     58 
     59     virtual ~Correction() {}
     60     void resetCorrection();
     61     void initCorrection(
     62             const ProximityInfo *pi, const int inputSize, const int maxWordLength);
     63     void initCorrectionState(const int rootPos, const int childCount, const bool traverseAll);
     64 
     65     // TODO: remove
     66     void setCorrectionParams(const int skipPos, const int excessivePos, const int transposedPos,
     67             const int spaceProximityPos, const int missingSpacePos, const bool useFullEditDistance,
     68             const bool doAutoCompletion, const int maxErrors);
     69     void checkState();
     70     bool sameAsTyped();
     71     bool initProcessState(const int index);
     72 
     73     int getInputIndex() const;
     74 
     75     bool needsToPrune() const;
     76 
     77     int pushAndGetTotalTraverseCount() {
     78         return ++mTotalTraverseCount;
     79     }
     80 
     81     int getFreqForSplitMultipleWords(
     82             const int *freqArray, const int *wordLengthArray, const int wordCount,
     83             const bool isSpaceProximity, const unsigned short *word);
     84     int getFinalProbability(const int probability, unsigned short **word, int *wordLength);
     85     int getFinalProbabilityForSubQueue(const int probability, unsigned short **word,
     86             int *wordLength, const int inputSize);
     87 
     88     CorrectionType processCharAndCalcState(const int32_t c, const bool isTerminal);
     89 
     90     /////////////////////////
     91     // Tree helper methods
     92     int goDownTree(const int parentIndex, const int childCount, const int firstChildPos);
     93 
     94     inline int getTreeSiblingPos(const int index) const {
     95         return mCorrectionStates[index].mSiblingPos;
     96     }
     97 
     98     inline void setTreeSiblingPos(const int index, const int pos) {
     99         mCorrectionStates[index].mSiblingPos = pos;
    100     }
    101 
    102     inline int getTreeParentIndex(const int index) const {
    103         return mCorrectionStates[index].mParentIndex;
    104     }
    105 
    106     class RankingAlgorithm {
    107      public:
    108         static int calculateFinalProbability(const int inputIndex, const int depth,
    109                 const int probability, int *editDistanceTable, const Correction *correction,
    110                 const int inputSize);
    111         static int calcFreqForSplitMultipleWords(const int *freqArray, const int *wordLengthArray,
    112                 const int wordCount, const Correction *correction, const bool isSpaceProximity,
    113                 const unsigned short *word);
    114         static float calcNormalizedScore(const unsigned short *before, const int beforeLength,
    115                 const unsigned short *after, const int afterLength, const int score);
    116         static int editDistance(const unsigned short *before,
    117                 const int beforeLength, const unsigned short *after, const int afterLength);
    118      private:
    119         static const int CODE_SPACE = ' ';
    120         static const int MAX_INITIAL_SCORE = 255;
    121     };
    122 
    123     // proximity info state
    124     void initInputParams(const ProximityInfo *proximityInfo, const int32_t *inputCodes,
    125             const int inputSize, const int *xCoordinates, const int *yCoordinates) {
    126         mProximityInfoState.initInputParams(0, MAX_POINT_TO_KEY_LENGTH,
    127                 proximityInfo, inputCodes, inputSize, xCoordinates, yCoordinates, 0, 0, false);
    128     }
    129 
    130     const unsigned short *getPrimaryInputWord() const {
    131         return mProximityInfoState.getPrimaryInputWord();
    132     }
    133 
    134     unsigned short getPrimaryCharAt(const int index) const {
    135         return mProximityInfoState.getPrimaryCharAt(index);
    136     }
    137 
    138  private:
    139     DISALLOW_COPY_AND_ASSIGN(Correction);
    140 
    141     /////////////////////////
    142     // static inline utils //
    143     /////////////////////////
    144     static const int TWO_31ST_DIV_255 = S_INT_MAX / 255;
    145     static inline int capped255MultForFullMatchAccentsOrCapitalizationDifference(const int num) {
    146         return (num < TWO_31ST_DIV_255 ? 255 * num : S_INT_MAX);
    147     }
    148 
    149     static const int TWO_31ST_DIV_2 = S_INT_MAX / 2;
    150     inline static void multiplyIntCapped(const int multiplier, int *base) {
    151         const int temp = *base;
    152         if (temp != S_INT_MAX) {
    153             // Branch if multiplier == 2 for the optimization
    154             if (multiplier < 0) {
    155                 if (DEBUG_DICT) {
    156                     assert(false);
    157                 }
    158                 AKLOGI("--- Invalid multiplier: %d", multiplier);
    159             } else if (multiplier == 0) {
    160                 *base = 0;
    161             } else if (multiplier == 2) {
    162                 *base = TWO_31ST_DIV_2 >= temp ? temp << 1 : S_INT_MAX;
    163             } else {
    164                 // TODO: This overflow check gives a wrong answer when, for example,
    165                 //       temp = 2^16 + 1 and multiplier = 2^17 + 1.
    166                 //       Fix this behavior.
    167                 const int tempRetval = temp * multiplier;
    168                 *base = tempRetval >= temp ? tempRetval : S_INT_MAX;
    169             }
    170         }
    171     }
    172 
    173     inline static int powerIntCapped(const int base, const int n) {
    174         if (n <= 0) return 1;
    175         if (base == 2) {
    176             return n < 31 ? 1 << n : S_INT_MAX;
    177         } else {
    178             int ret = base;
    179             for (int i = 1; i < n; ++i) multiplyIntCapped(base, &ret);
    180             return ret;
    181         }
    182     }
    183 
    184     inline static void multiplyRate(const int rate, int *freq) {
    185         if (*freq != S_INT_MAX) {
    186             if (*freq > 1000000) {
    187                 *freq /= 100;
    188                 multiplyIntCapped(rate, freq);
    189             } else {
    190                 multiplyIntCapped(rate, freq);
    191                 *freq /= 100;
    192             }
    193         }
    194     }
    195 
    196     inline int getSpaceProximityPos() const {
    197         return mSpaceProximityPos;
    198     }
    199     inline int getMissingSpacePos() const {
    200         return mMissingSpacePos;
    201     }
    202 
    203     inline int getSkipPos() const {
    204         return mSkipPos;
    205     }
    206 
    207     inline int getExcessivePos() const {
    208         return mExcessivePos;
    209     }
    210 
    211     inline int getTransposedPos() const {
    212         return mTransposedPos;
    213     }
    214 
    215     inline void incrementInputIndex();
    216     inline void incrementOutputIndex();
    217     inline void startToTraverseAllNodes();
    218     inline bool isSingleQuote(const unsigned short c);
    219     inline CorrectionType processSkipChar(
    220             const int32_t c, const bool isTerminal, const bool inputIndexIncremented);
    221     inline CorrectionType processUnrelatedCorrectionType();
    222     inline void addCharToCurrentWord(const int32_t c);
    223     inline int getFinalProbabilityInternal(const int probability, unsigned short **word,
    224             int *wordLength, const int inputSize);
    225 
    226     static const int TYPED_LETTER_MULTIPLIER = 2;
    227     static const int FULL_WORD_MULTIPLIER = 2;
    228     const ProximityInfo *mProximityInfo;
    229 
    230     bool mUseFullEditDistance;
    231     bool mDoAutoCompletion;
    232     int mMaxEditDistance;
    233     int mMaxDepth;
    234     int mInputSize;
    235     int mSpaceProximityPos;
    236     int mMissingSpacePos;
    237     int mTerminalInputIndex;
    238     int mTerminalOutputIndex;
    239     int mMaxErrors;
    240 
    241     uint8_t mTotalTraverseCount;
    242 
    243     // The following arrays are state buffer.
    244     unsigned short mWord[MAX_WORD_LENGTH_INTERNAL];
    245     int mDistances[MAX_WORD_LENGTH_INTERNAL];
    246 
    247     // Edit distance calculation requires a buffer with (N+1)^2 length for the input length N.
    248     // Caveat: Do not create multiple tables per thread as this table eats up RAM a lot.
    249     int mEditDistanceTable[(MAX_WORD_LENGTH_INTERNAL + 1) * (MAX_WORD_LENGTH_INTERNAL + 1)];
    250 
    251     CorrectionState mCorrectionStates[MAX_WORD_LENGTH_INTERNAL];
    252 
    253     // The following member variables are being used as cache values of the correction state.
    254     bool mNeedsToTraverseAllNodes;
    255     int mOutputIndex;
    256     int mInputIndex;
    257 
    258     int mEquivalentCharCount;
    259     int mProximityCount;
    260     int mExcessiveCount;
    261     int mTransposedCount;
    262     int mSkippedCount;
    263 
    264     int mTransposedPos;
    265     int mExcessivePos;
    266     int mSkipPos;
    267 
    268     bool mLastCharExceeded;
    269 
    270     bool mMatching;
    271     bool mProximityMatching;
    272     bool mAdditionalProximityMatching;
    273     bool mExceeding;
    274     bool mTransposing;
    275     bool mSkipping;
    276     ProximityInfoState mProximityInfoState;
    277 };
    278 } // namespace latinime
    279 #endif // LATINIME_CORRECTION_H
    280