Home | History | Annotate | Download | only in typing
      1 /*
      2  * Copyright (C) 2012 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef LATINIME_TYPING_WEIGHTING_H
     18 #define LATINIME_TYPING_WEIGHTING_H
     19 
     20 #include "defines.h"
     21 #include "suggest/core/dicnode/dic_node_utils.h"
     22 #include "suggest/core/dictionary/error_type_utils.h"
     23 #include "suggest/core/layout/touch_position_correction_utils.h"
     24 #include "suggest/core/policy/weighting.h"
     25 #include "suggest/core/session/dic_traverse_session.h"
     26 #include "suggest/policyimpl/typing/scoring_params.h"
     27 #include "utils/char_utils.h"
     28 
     29 namespace latinime {
     30 
     31 class DicNode;
     32 struct DicNode_InputStateG;
     33 class MultiBigramMap;
     34 
     35 class TypingWeighting : public Weighting {
     36  public:
     37     static const TypingWeighting *getInstance() { return &sInstance; }
     38 
     39  protected:
     40     float getTerminalSpatialCost(const DicTraverseSession *const traverseSession,
     41             const DicNode *const dicNode) const {
     42         float cost = 0.0f;
     43         if (dicNode->hasMultipleWords()) {
     44             cost += ScoringParams::HAS_MULTI_WORD_TERMINAL_COST;
     45         }
     46         if (dicNode->getProximityCorrectionCount() > 0) {
     47             cost += ScoringParams::HAS_PROXIMITY_TERMINAL_COST;
     48         }
     49         if (dicNode->getEditCorrectionCount() > 0) {
     50             cost += ScoringParams::HAS_EDIT_CORRECTION_TERMINAL_COST;
     51         }
     52         return cost;
     53     }
     54 
     55     float getOmissionCost(const DicNode *const parentDicNode, const DicNode *const dicNode) const {
     56         const bool isZeroCostOmission = parentDicNode->isZeroCostOmission();
     57         const bool isIntentionalOmission = parentDicNode->canBeIntentionalOmission();
     58         const bool sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode);
     59         // If the traversal omitted the first letter then the dicNode should now be on the second.
     60         const bool isFirstLetterOmission = dicNode->getNodeCodePointCount() == 2;
     61         float cost = 0.0f;
     62         if (isZeroCostOmission) {
     63             cost = 0.0f;
     64         } else if (isIntentionalOmission) {
     65             cost = ScoringParams::INTENTIONAL_OMISSION_COST;
     66         } else if (isFirstLetterOmission) {
     67             cost = ScoringParams::OMISSION_COST_FIRST_CHAR;
     68         } else {
     69             cost = sameCodePoint ? ScoringParams::OMISSION_COST_SAME_CHAR
     70                     : ScoringParams::OMISSION_COST;
     71         }
     72         return cost;
     73     }
     74 
     75     float getMatchedCost(const DicTraverseSession *const traverseSession,
     76             const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const {
     77         const int pointIndex = dicNode->getInputIndex(0);
     78         const float normalizedSquaredLength = traverseSession->getProximityInfoState(0)
     79                 ->getPointToKeyLength(pointIndex,
     80                         CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint()));
     81         const float normalizedDistance = TouchPositionCorrectionUtils::getSweetSpotFactor(
     82                 traverseSession->isTouchPositionCorrectionEnabled(), normalizedSquaredLength);
     83         const float weightedDistance = ScoringParams::DISTANCE_WEIGHT_LENGTH * normalizedDistance;
     84 
     85         const bool isFirstChar = pointIndex == 0;
     86         const bool isProximity = isProximityDicNode(traverseSession, dicNode);
     87         float cost = isProximity ? (isFirstChar ? ScoringParams::FIRST_CHAR_PROXIMITY_COST
     88                 : ScoringParams::PROXIMITY_COST) : 0.0f;
     89         if (isProximity && dicNode->getProximityCorrectionCount() == 0) {
     90             cost += ScoringParams::FIRST_PROXIMITY_COST;
     91         }
     92         if (dicNode->getNodeCodePointCount() == 2) {
     93             // At the second character of the current word, we check if the first char is uppercase
     94             // and the word is a second or later word of a multiple word suggestion. We demote it
     95             // if so.
     96             const bool isSecondOrLaterWordFirstCharUppercase =
     97                     dicNode->hasMultipleWords() && dicNode->isFirstCharUppercase();
     98             if (isSecondOrLaterWordFirstCharUppercase) {
     99                 cost += ScoringParams::COST_SECOND_OR_LATER_WORD_FIRST_CHAR_UPPERCASE;
    100             }
    101         }
    102         return weightedDistance + cost;
    103     }
    104 
    105     bool isProximityDicNode(const DicTraverseSession *const traverseSession,
    106             const DicNode *const dicNode) const {
    107         const int pointIndex = dicNode->getInputIndex(0);
    108         const int primaryCodePoint = CharUtils::toBaseLowerCase(
    109                 traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(pointIndex));
    110         const int dicNodeChar = CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint());
    111         return primaryCodePoint != dicNodeChar;
    112     }
    113 
    114     float getTranspositionCost(const DicTraverseSession *const traverseSession,
    115             const DicNode *const parentDicNode, const DicNode *const dicNode) const {
    116         const int16_t parentPointIndex = parentDicNode->getInputIndex(0);
    117         const int prevCodePoint = parentDicNode->getNodeCodePoint();
    118         const float distance1 = traverseSession->getProximityInfoState(0)->getPointToKeyLength(
    119                 parentPointIndex + 1, CharUtils::toBaseLowerCase(prevCodePoint));
    120         const int codePoint = dicNode->getNodeCodePoint();
    121         const float distance2 = traverseSession->getProximityInfoState(0)->getPointToKeyLength(
    122                 parentPointIndex, CharUtils::toBaseLowerCase(codePoint));
    123         const float distance = distance1 + distance2;
    124         const float weightedLengthDistance =
    125                 distance * ScoringParams::DISTANCE_WEIGHT_LENGTH;
    126         return ScoringParams::TRANSPOSITION_COST + weightedLengthDistance;
    127     }
    128 
    129     float getInsertionCost(const DicTraverseSession *const traverseSession,
    130             const DicNode *const parentDicNode, const DicNode *const dicNode) const {
    131         const int16_t insertedPointIndex = parentDicNode->getInputIndex(0);
    132         const int prevCodePoint = traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(
    133                 insertedPointIndex);
    134         const int currentCodePoint = dicNode->getNodeCodePoint();
    135         const bool sameCodePoint = prevCodePoint == currentCodePoint;
    136         const bool existsAdjacentProximityChars = traverseSession->getProximityInfoState(0)
    137                 ->existsAdjacentProximityChars(insertedPointIndex);
    138         const float dist = traverseSession->getProximityInfoState(0)->getPointToKeyLength(
    139                 insertedPointIndex + 1, CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint()));
    140         const float weightedDistance = dist * ScoringParams::DISTANCE_WEIGHT_LENGTH;
    141         const bool singleChar = dicNode->getNodeCodePointCount() == 1;
    142         float cost = (singleChar ? ScoringParams::INSERTION_COST_FIRST_CHAR : 0.0f);
    143         if (sameCodePoint) {
    144             cost += ScoringParams::INSERTION_COST_SAME_CHAR;
    145         } else if (existsAdjacentProximityChars) {
    146             cost += ScoringParams::INSERTION_COST_PROXIMITY_CHAR;
    147         } else {
    148             cost += ScoringParams::INSERTION_COST;
    149         }
    150         return cost + weightedDistance;
    151     }
    152 
    153     float getSpaceOmissionCost(const DicTraverseSession *const traverseSession,
    154             const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const {
    155         const float cost = ScoringParams::SPACE_OMISSION_COST;
    156         return cost * traverseSession->getMultiWordCostMultiplier();
    157     }
    158 
    159     float getNewWordBigramLanguageCost(const DicTraverseSession *const traverseSession,
    160             const DicNode *const dicNode,
    161             MultiBigramMap *const multiBigramMap) const {
    162         return DicNodeUtils::getBigramNodeImprobability(
    163                 traverseSession->getDictionaryStructurePolicy(),
    164                 dicNode, multiBigramMap) * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
    165     }
    166 
    167     float getCompletionCost(const DicTraverseSession *const traverseSession,
    168             const DicNode *const dicNode) const {
    169         // The auto completion starts when the input index is same as the input size
    170         const bool firstCompletion = dicNode->getInputIndex(0)
    171                 == traverseSession->getInputSize();
    172         // TODO: Change the cost for the first completion for the gesture?
    173         const float cost = firstCompletion ? ScoringParams::COST_FIRST_COMPLETION
    174                 : ScoringParams::COST_COMPLETION;
    175         return cost;
    176     }
    177 
    178     float getTerminalLanguageCost(const DicTraverseSession *const traverseSession,
    179             const DicNode *const dicNode, const float dicNodeLanguageImprobability) const {
    180         return dicNodeLanguageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
    181     }
    182 
    183     float getTerminalInsertionCost(const DicTraverseSession *const traverseSession,
    184             const DicNode *const dicNode) const {
    185         const int inputIndex = dicNode->getInputIndex(0);
    186         const int inputSize = traverseSession->getInputSize();
    187         ASSERT(inputIndex < inputSize);
    188         // TODO: Implement more efficient logic
    189         return  ScoringParams::TERMINAL_INSERTION_COST * (inputSize - inputIndex);
    190     }
    191 
    192     AK_FORCE_INLINE bool needsToNormalizeCompoundDistance() const {
    193         return false;
    194     }
    195 
    196     AK_FORCE_INLINE float getAdditionalProximityCost() const {
    197         return ScoringParams::ADDITIONAL_PROXIMITY_COST;
    198     }
    199 
    200     AK_FORCE_INLINE float getSubstitutionCost() const {
    201         return ScoringParams::SUBSTITUTION_COST;
    202     }
    203 
    204     AK_FORCE_INLINE float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession,
    205             const DicNode *const dicNode) const {
    206         const int inputIndex = dicNode->getInputIndex(0);
    207         const float distanceToSpaceKey = traverseSession->getProximityInfoState(0)
    208                 ->getPointToKeyLength(inputIndex, KEYCODE_SPACE);
    209         const float cost = ScoringParams::SPACE_SUBSTITUTION_COST * distanceToSpaceKey;
    210         return cost * traverseSession->getMultiWordCostMultiplier();
    211     }
    212 
    213     ErrorTypeUtils::ErrorType getErrorType(const CorrectionType correctionType,
    214             const DicTraverseSession *const traverseSession,
    215             const DicNode *const parentDicNode, const DicNode *const dicNode) const;
    216 
    217  private:
    218     DISALLOW_COPY_AND_ASSIGN(TypingWeighting);
    219     static const TypingWeighting sInstance;
    220 
    221     TypingWeighting() {}
    222     ~TypingWeighting() {}
    223 };
    224 } // namespace latinime
    225 #endif // LATINIME_TYPING_WEIGHTING_H
    226