Home | History | Annotate | Download | only in internal
      1 /*
      2  * Copyright (C) 2012 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef LATINIME_DIC_NODE_STATE_SCORING_H
     18 #define LATINIME_DIC_NODE_STATE_SCORING_H
     19 
     20 #include <algorithm>
     21 #include <cstdint>
     22 
     23 #include "defines.h"
     24 #include "suggest/core/dictionary/digraph_utils.h"
     25 #include "suggest/core/dictionary/error_type_utils.h"
     26 
     27 namespace latinime {
     28 
     29 class DicNodeStateScoring {
     30  public:
     31     AK_FORCE_INLINE DicNodeStateScoring()
     32             : mDoubleLetterLevel(NOT_A_DOUBLE_LETTER),
     33               mDigraphIndex(DigraphUtils::NOT_A_DIGRAPH_INDEX),
     34               mEditCorrectionCount(0), mProximityCorrectionCount(0), mCompletionCount(0),
     35               mNormalizedCompoundDistance(0.0f), mSpatialDistance(0.0f), mLanguageDistance(0.0f),
     36               mRawLength(0.0f), mContainedErrorTypes(ErrorTypeUtils::NOT_AN_ERROR),
     37               mNormalizedCompoundDistanceAfterFirstWord(MAX_VALUE_FOR_WEIGHTING) {
     38     }
     39 
     40     ~DicNodeStateScoring() {}
     41 
     42     void init() {
     43         mEditCorrectionCount = 0;
     44         mProximityCorrectionCount = 0;
     45         mCompletionCount = 0;
     46         mNormalizedCompoundDistance = 0.0f;
     47         mSpatialDistance = 0.0f;
     48         mLanguageDistance = 0.0f;
     49         mRawLength = 0.0f;
     50         mDoubleLetterLevel = NOT_A_DOUBLE_LETTER;
     51         mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX;
     52         mNormalizedCompoundDistanceAfterFirstWord = MAX_VALUE_FOR_WEIGHTING;
     53         mContainedErrorTypes = ErrorTypeUtils::NOT_AN_ERROR;
     54     }
     55 
     56     AK_FORCE_INLINE void initByCopy(const DicNodeStateScoring *const scoring) {
     57         mEditCorrectionCount = scoring->mEditCorrectionCount;
     58         mProximityCorrectionCount = scoring->mProximityCorrectionCount;
     59         mCompletionCount = scoring->mCompletionCount;
     60         mNormalizedCompoundDistance = scoring->mNormalizedCompoundDistance;
     61         mSpatialDistance = scoring->mSpatialDistance;
     62         mLanguageDistance = scoring->mLanguageDistance;
     63         mRawLength = scoring->mRawLength;
     64         mDoubleLetterLevel = scoring->mDoubleLetterLevel;
     65         mDigraphIndex = scoring->mDigraphIndex;
     66         mContainedErrorTypes = scoring->mContainedErrorTypes;
     67         mNormalizedCompoundDistanceAfterFirstWord =
     68                 scoring->mNormalizedCompoundDistanceAfterFirstWord;
     69     }
     70 
     71     void addCost(const float spatialCost, const float languageCost, const bool doNormalization,
     72             const int inputSize, const int totalInputIndex,
     73             const ErrorTypeUtils::ErrorType errorType) {
     74         addDistance(spatialCost, languageCost, doNormalization, inputSize, totalInputIndex);
     75         mContainedErrorTypes = mContainedErrorTypes | errorType;
     76         if (ErrorTypeUtils::isEditCorrectionError(errorType)) {
     77             ++mEditCorrectionCount;
     78         }
     79         if (ErrorTypeUtils::isProximityCorrectionError(errorType)) {
     80             ++mProximityCorrectionCount;
     81         }
     82         if (ErrorTypeUtils::isCompletion(errorType)) {
     83             ++mCompletionCount;
     84         }
     85     }
     86 
     87     // Saves the current normalized distance for space-aware gestures.
     88     // See getNormalizedCompoundDistanceAfterFirstWord for details.
     89     void saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet() {
     90         // We get called here after each word. We only want to store the distance after
     91         // the first word, so if we already have a distance we skip saving -- hence "IfNoneYet"
     92         // in the method name.
     93         if (mNormalizedCompoundDistanceAfterFirstWord >= MAX_VALUE_FOR_WEIGHTING) {
     94             mNormalizedCompoundDistanceAfterFirstWord = getNormalizedCompoundDistance();
     95         }
     96     }
     97 
     98     void addRawLength(const float rawLength) {
     99         mRawLength += rawLength;
    100     }
    101 
    102     float getCompoundDistance() const {
    103         return getCompoundDistance(1.0f);
    104     }
    105 
    106     float getCompoundDistance(const float languageWeight) const {
    107         return mSpatialDistance + mLanguageDistance * languageWeight;
    108     }
    109 
    110     float getNormalizedCompoundDistance() const {
    111         return mNormalizedCompoundDistance;
    112     }
    113 
    114     // For space-aware gestures, we store the normalized distance at the char index
    115     // that ends the first word of the suggestion. We call this the distance after
    116     // first word.
    117     float getNormalizedCompoundDistanceAfterFirstWord() const {
    118         return mNormalizedCompoundDistanceAfterFirstWord;
    119     }
    120 
    121     float getSpatialDistance() const {
    122         return mSpatialDistance;
    123     }
    124 
    125     float getLanguageDistance() const {
    126         return mLanguageDistance;
    127     }
    128 
    129     int16_t getEditCorrectionCount() const {
    130         return mEditCorrectionCount;
    131     }
    132 
    133     int16_t getProximityCorrectionCount() const {
    134         return mProximityCorrectionCount;
    135     }
    136 
    137     int16_t getCompletionCount() const {
    138         return mCompletionCount;
    139     }
    140 
    141     float getRawLength() const {
    142         return mRawLength;
    143     }
    144 
    145     DoubleLetterLevel getDoubleLetterLevel() const {
    146         return mDoubleLetterLevel;
    147     }
    148 
    149     void setDoubleLetterLevel(DoubleLetterLevel doubleLetterLevel) {
    150         switch(doubleLetterLevel) {
    151             case NOT_A_DOUBLE_LETTER:
    152                 break;
    153             case A_DOUBLE_LETTER:
    154                 if (mDoubleLetterLevel != A_STRONG_DOUBLE_LETTER) {
    155                     mDoubleLetterLevel = doubleLetterLevel;
    156                 }
    157                 break;
    158             case A_STRONG_DOUBLE_LETTER:
    159                 mDoubleLetterLevel = doubleLetterLevel;
    160                 break;
    161         }
    162     }
    163 
    164     DigraphUtils::DigraphCodePointIndex getDigraphIndex() const {
    165         return mDigraphIndex;
    166     }
    167 
    168     void advanceDigraphIndex() {
    169         switch(mDigraphIndex) {
    170             case DigraphUtils::NOT_A_DIGRAPH_INDEX:
    171                 mDigraphIndex = DigraphUtils::FIRST_DIGRAPH_CODEPOINT;
    172                 break;
    173             case DigraphUtils::FIRST_DIGRAPH_CODEPOINT:
    174                 mDigraphIndex = DigraphUtils::SECOND_DIGRAPH_CODEPOINT;
    175                 break;
    176             case DigraphUtils::SECOND_DIGRAPH_CODEPOINT:
    177                 mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX;
    178                 break;
    179         }
    180     }
    181 
    182     ErrorTypeUtils::ErrorType getContainedErrorTypes() const {
    183         return mContainedErrorTypes;
    184     }
    185 
    186  private:
    187     DISALLOW_COPY_AND_ASSIGN(DicNodeStateScoring);
    188 
    189     DoubleLetterLevel mDoubleLetterLevel;
    190     DigraphUtils::DigraphCodePointIndex mDigraphIndex;
    191 
    192     int16_t mEditCorrectionCount;
    193     int16_t mProximityCorrectionCount;
    194     int16_t mCompletionCount;
    195 
    196     float mNormalizedCompoundDistance;
    197     float mSpatialDistance;
    198     float mLanguageDistance;
    199     float mRawLength;
    200     // All accumulated error types so far
    201     ErrorTypeUtils::ErrorType mContainedErrorTypes;
    202     float mNormalizedCompoundDistanceAfterFirstWord;
    203 
    204     AK_FORCE_INLINE void addDistance(float spatialDistance, float languageDistance,
    205             bool doNormalization, int inputSize, int totalInputIndex) {
    206         mSpatialDistance += spatialDistance;
    207         mLanguageDistance += languageDistance;
    208         if (!doNormalization) {
    209             mNormalizedCompoundDistance = mSpatialDistance + mLanguageDistance;
    210         } else {
    211             mNormalizedCompoundDistance = (mSpatialDistance + mLanguageDistance)
    212                     / static_cast<float>(std::max(1, totalInputIndex));
    213         }
    214     }
    215 };
    216 } // namespace latinime
    217 #endif // LATINIME_DIC_NODE_STATE_SCORING_H
    218