Home | History | Annotate | Download | only in internal
      1 /*
      2  * Copyright (C) 2012 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef LATINIME_DIC_NODE_STATE_SCORING_H
     18 #define LATINIME_DIC_NODE_STATE_SCORING_H
     19 
     20 #include <algorithm>
     21 #include <cstdint>
     22 
     23 #include "defines.h"
     24 #include "suggest/core/dictionary/digraph_utils.h"
     25 #include "suggest/core/dictionary/error_type_utils.h"
     26 
     27 namespace latinime {
     28 
     29 class DicNodeStateScoring {
     30  public:
     31     AK_FORCE_INLINE DicNodeStateScoring()
     32             : mDoubleLetterLevel(NOT_A_DOUBLE_LETTER),
     33               mDigraphIndex(DigraphUtils::NOT_A_DIGRAPH_INDEX),
     34               mEditCorrectionCount(0), mProximityCorrectionCount(0), mCompletionCount(0),
     35               mNormalizedCompoundDistance(0.0f), mSpatialDistance(0.0f), mLanguageDistance(0.0f),
     36               mRawLength(0.0f), mContainedErrorTypes(ErrorTypeUtils::NOT_AN_ERROR),
     37               mNormalizedCompoundDistanceAfterFirstWord(MAX_VALUE_FOR_WEIGHTING) {
     38     }
     39 
     40     ~DicNodeStateScoring() {}
     41 
     42     void init() {
     43         mEditCorrectionCount = 0;
     44         mProximityCorrectionCount = 0;
     45         mCompletionCount = 0;
     46         mNormalizedCompoundDistance = 0.0f;
     47         mSpatialDistance = 0.0f;
     48         mLanguageDistance = 0.0f;
     49         mRawLength = 0.0f;
     50         mDoubleLetterLevel = NOT_A_DOUBLE_LETTER;
     51         mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX;
     52         mNormalizedCompoundDistanceAfterFirstWord = MAX_VALUE_FOR_WEIGHTING;
     53         mContainedErrorTypes = ErrorTypeUtils::NOT_AN_ERROR;
     54     }
     55 
     56     AK_FORCE_INLINE void initByCopy(const DicNodeStateScoring *const scoring) {
     57         mEditCorrectionCount = scoring->mEditCorrectionCount;
     58         mProximityCorrectionCount = scoring->mProximityCorrectionCount;
     59         mCompletionCount = scoring->mCompletionCount;
     60         mNormalizedCompoundDistance = scoring->mNormalizedCompoundDistance;
     61         mSpatialDistance = scoring->mSpatialDistance;
     62         mLanguageDistance = scoring->mLanguageDistance;
     63         mRawLength = scoring->mRawLength;
     64         mDoubleLetterLevel = scoring->mDoubleLetterLevel;
     65         mDigraphIndex = scoring->mDigraphIndex;
     66         mContainedErrorTypes = scoring->mContainedErrorTypes;
     67         mNormalizedCompoundDistanceAfterFirstWord =
     68                 scoring->mNormalizedCompoundDistanceAfterFirstWord;
     69     }
     70 
     71     void addCost(const float spatialCost, const float languageCost, const bool doNormalization,
     72             const int inputSize, const int totalInputIndex,
     73             const ErrorTypeUtils::ErrorType errorType) {
     74         addDistance(spatialCost, languageCost, doNormalization, inputSize, totalInputIndex);
     75         mContainedErrorTypes = mContainedErrorTypes | errorType;
     76         if (ErrorTypeUtils::isEditCorrectionError(errorType)) {
     77             ++mEditCorrectionCount;
     78         }
     79         if (ErrorTypeUtils::isProximityCorrectionError(errorType)) {
     80             ++mProximityCorrectionCount;
     81         }
     82         if (ErrorTypeUtils::isCompletion(errorType)) {
     83             ++mCompletionCount;
     84         }
     85     }
     86 
     87     // Saves the current normalized distance for space-aware gestures.
     88     // See getNormalizedCompoundDistanceAfterFirstWord for details.
     89     void saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet() {
     90         // We get called here after each word. We only want to store the distance after
     91         // the first word, so if we already have a distance we skip saving -- hence "IfNoneYet"
     92         // in the method name.
     93         if (mNormalizedCompoundDistanceAfterFirstWord >= MAX_VALUE_FOR_WEIGHTING) {
     94             mNormalizedCompoundDistanceAfterFirstWord = getNormalizedCompoundDistance();
     95         }
     96     }
     97 
     98     void addRawLength(const float rawLength) {
     99         mRawLength += rawLength;
    100     }
    101 
    102     float getCompoundDistance() const {
    103         return getCompoundDistance(1.0f);
    104     }
    105 
    106     float getCompoundDistance(
    107             const float weightOfLangModelVsSpatialModel) const {
    108         return mSpatialDistance
    109                 + mLanguageDistance * weightOfLangModelVsSpatialModel;
    110     }
    111 
    112     float getNormalizedCompoundDistance() const {
    113         return mNormalizedCompoundDistance;
    114     }
    115 
    116     // For space-aware gestures, we store the normalized distance at the char index
    117     // that ends the first word of the suggestion. We call this the distance after
    118     // first word.
    119     float getNormalizedCompoundDistanceAfterFirstWord() const {
    120         return mNormalizedCompoundDistanceAfterFirstWord;
    121     }
    122 
    123     float getSpatialDistance() const {
    124         return mSpatialDistance;
    125     }
    126 
    127     float getLanguageDistance() const {
    128         return mLanguageDistance;
    129     }
    130 
    131     int16_t getEditCorrectionCount() const {
    132         return mEditCorrectionCount;
    133     }
    134 
    135     int16_t getProximityCorrectionCount() const {
    136         return mProximityCorrectionCount;
    137     }
    138 
    139     int16_t getCompletionCount() const {
    140         return mCompletionCount;
    141     }
    142 
    143     float getRawLength() const {
    144         return mRawLength;
    145     }
    146 
    147     DoubleLetterLevel getDoubleLetterLevel() const {
    148         return mDoubleLetterLevel;
    149     }
    150 
    151     void setDoubleLetterLevel(DoubleLetterLevel doubleLetterLevel) {
    152         switch(doubleLetterLevel) {
    153             case NOT_A_DOUBLE_LETTER:
    154                 break;
    155             case A_DOUBLE_LETTER:
    156                 if (mDoubleLetterLevel != A_STRONG_DOUBLE_LETTER) {
    157                     mDoubleLetterLevel = doubleLetterLevel;
    158                 }
    159                 break;
    160             case A_STRONG_DOUBLE_LETTER:
    161                 mDoubleLetterLevel = doubleLetterLevel;
    162                 break;
    163         }
    164     }
    165 
    166     DigraphUtils::DigraphCodePointIndex getDigraphIndex() const {
    167         return mDigraphIndex;
    168     }
    169 
    170     void advanceDigraphIndex() {
    171         switch(mDigraphIndex) {
    172             case DigraphUtils::NOT_A_DIGRAPH_INDEX:
    173                 mDigraphIndex = DigraphUtils::FIRST_DIGRAPH_CODEPOINT;
    174                 break;
    175             case DigraphUtils::FIRST_DIGRAPH_CODEPOINT:
    176                 mDigraphIndex = DigraphUtils::SECOND_DIGRAPH_CODEPOINT;
    177                 break;
    178             case DigraphUtils::SECOND_DIGRAPH_CODEPOINT:
    179                 mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX;
    180                 break;
    181         }
    182     }
    183 
    184     ErrorTypeUtils::ErrorType getContainedErrorTypes() const {
    185         return mContainedErrorTypes;
    186     }
    187 
    188  private:
    189     DISALLOW_COPY_AND_ASSIGN(DicNodeStateScoring);
    190 
    191     DoubleLetterLevel mDoubleLetterLevel;
    192     DigraphUtils::DigraphCodePointIndex mDigraphIndex;
    193 
    194     int16_t mEditCorrectionCount;
    195     int16_t mProximityCorrectionCount;
    196     int16_t mCompletionCount;
    197 
    198     float mNormalizedCompoundDistance;
    199     float mSpatialDistance;
    200     float mLanguageDistance;
    201     float mRawLength;
    202     // All accumulated error types so far
    203     ErrorTypeUtils::ErrorType mContainedErrorTypes;
    204     float mNormalizedCompoundDistanceAfterFirstWord;
    205 
    206     AK_FORCE_INLINE void addDistance(float spatialDistance, float languageDistance,
    207             bool doNormalization, int inputSize, int totalInputIndex) {
    208         mSpatialDistance += spatialDistance;
    209         mLanguageDistance += languageDistance;
    210         if (!doNormalization) {
    211             mNormalizedCompoundDistance = mSpatialDistance + mLanguageDistance;
    212         } else {
    213             mNormalizedCompoundDistance = (mSpatialDistance + mLanguageDistance)
    214                     / static_cast<float>(std::max(1, totalInputIndex));
    215         }
    216     }
    217 };
    218 } // namespace latinime
    219 #endif // LATINIME_DIC_NODE_STATE_SCORING_H
    220