1 /* 2 * Copyright (C) 2012 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef LATINIME_TYPING_WEIGHTING_H 18 #define LATINIME_TYPING_WEIGHTING_H 19 20 #include "defines.h" 21 #include "suggest_utils.h" 22 #include "suggest/core/dicnode/dic_node_utils.h" 23 #include "suggest/core/policy/weighting.h" 24 #include "suggest/core/session/dic_traverse_session.h" 25 #include "suggest/policyimpl/typing/scoring_params.h" 26 27 namespace latinime { 28 29 class DicNode; 30 struct DicNode_InputStateG; 31 class MultiBigramMap; 32 33 class TypingWeighting : public Weighting { 34 public: 35 static const TypingWeighting *getInstance() { return &sInstance; } 36 37 protected: 38 float getTerminalSpatialCost(const DicTraverseSession *const traverseSession, 39 const DicNode *const dicNode) const { 40 float cost = 0.0f; 41 if (dicNode->hasMultipleWords()) { 42 cost += ScoringParams::HAS_MULTI_WORD_TERMINAL_COST; 43 } 44 if (dicNode->getProximityCorrectionCount() > 0) { 45 cost += ScoringParams::HAS_PROXIMITY_TERMINAL_COST; 46 } 47 if (dicNode->getEditCorrectionCount() > 0) { 48 cost += ScoringParams::HAS_EDIT_CORRECTION_TERMINAL_COST; 49 } 50 return cost; 51 } 52 53 float getOmissionCost(const DicNode *const parentDicNode, const DicNode *const dicNode) const { 54 const bool isZeroCostOmission = parentDicNode->isZeroCostOmission(); 55 const bool sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode); 56 // If the traversal omitted the first letter then the dicNode should now be on the second. 57 const bool isFirstLetterOmission = dicNode->getDepth() == 2; 58 float cost = 0.0f; 59 if (isZeroCostOmission) { 60 cost = 0.0f; 61 } else if (isFirstLetterOmission) { 62 cost = ScoringParams::OMISSION_COST_FIRST_CHAR; 63 } else { 64 cost = sameCodePoint ? ScoringParams::OMISSION_COST_SAME_CHAR 65 : ScoringParams::OMISSION_COST; 66 } 67 return cost; 68 } 69 70 float getMatchedCost(const DicTraverseSession *const traverseSession, 71 const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const { 72 const int pointIndex = dicNode->getInputIndex(0); 73 // Note: min() required since length can be MAX_POINT_TO_KEY_LENGTH for characters not on 74 // the keyboard (like accented letters) 75 const float normalizedSquaredLength = traverseSession->getProximityInfoState(0) 76 ->getPointToKeyLength(pointIndex, dicNode->getNodeCodePoint()); 77 const float normalizedDistance = SuggestUtils::getSweetSpotFactor( 78 traverseSession->isTouchPositionCorrectionEnabled(), normalizedSquaredLength); 79 const float weightedDistance = ScoringParams::DISTANCE_WEIGHT_LENGTH * normalizedDistance; 80 81 const bool isFirstChar = pointIndex == 0; 82 const bool isProximity = isProximityDicNode(traverseSession, dicNode); 83 float cost = isProximity ? (isFirstChar ? ScoringParams::FIRST_PROXIMITY_COST 84 : ScoringParams::PROXIMITY_COST) : 0.0f; 85 if (dicNode->getDepth() == 2) { 86 // At the second character of the current word, we check if the first char is uppercase 87 // and the word is a second or later word of a multiple word suggestion. We demote it 88 // if so. 89 const bool isSecondOrLaterWordFirstCharUppercase = 90 dicNode->hasMultipleWords() && dicNode->isFirstCharUppercase(); 91 if (isSecondOrLaterWordFirstCharUppercase) { 92 cost += ScoringParams::COST_SECOND_OR_LATER_WORD_FIRST_CHAR_UPPERCASE; 93 } 94 } 95 return weightedDistance + cost; 96 } 97 98 bool isProximityDicNode(const DicTraverseSession *const traverseSession, 99 const DicNode *const dicNode) const { 100 const int pointIndex = dicNode->getInputIndex(0); 101 const int primaryCodePoint = toBaseLowerCase( 102 traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(pointIndex)); 103 const int dicNodeChar = toBaseLowerCase(dicNode->getNodeCodePoint()); 104 return primaryCodePoint != dicNodeChar; 105 } 106 107 float getTranspositionCost(const DicTraverseSession *const traverseSession, 108 const DicNode *const parentDicNode, const DicNode *const dicNode) const { 109 const int16_t parentPointIndex = parentDicNode->getInputIndex(0); 110 const int prevCodePoint = parentDicNode->getNodeCodePoint(); 111 const float distance1 = traverseSession->getProximityInfoState(0)->getPointToKeyLength( 112 parentPointIndex + 1, prevCodePoint); 113 const int codePoint = dicNode->getNodeCodePoint(); 114 const float distance2 = traverseSession->getProximityInfoState(0)->getPointToKeyLength( 115 parentPointIndex, codePoint); 116 const float distance = distance1 + distance2; 117 const float weightedLengthDistance = 118 distance * ScoringParams::DISTANCE_WEIGHT_LENGTH; 119 return ScoringParams::TRANSPOSITION_COST + weightedLengthDistance; 120 } 121 122 float getInsertionCost(const DicTraverseSession *const traverseSession, 123 const DicNode *const parentDicNode, const DicNode *const dicNode) const { 124 const int16_t parentPointIndex = parentDicNode->getInputIndex(0); 125 const int prevCodePoint = 126 traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(parentPointIndex); 127 128 const int currentCodePoint = dicNode->getNodeCodePoint(); 129 const bool sameCodePoint = prevCodePoint == currentCodePoint; 130 const float dist = traverseSession->getProximityInfoState(0)->getPointToKeyLength( 131 parentPointIndex + 1, currentCodePoint); 132 const float weightedDistance = dist * ScoringParams::DISTANCE_WEIGHT_LENGTH; 133 const bool singleChar = dicNode->getDepth() == 1; 134 const float cost = (singleChar ? ScoringParams::INSERTION_COST_FIRST_CHAR : 0.0f) 135 + (sameCodePoint ? ScoringParams::INSERTION_COST_SAME_CHAR 136 : ScoringParams::INSERTION_COST); 137 return cost + weightedDistance; 138 } 139 140 float getNewWordCost(const DicTraverseSession *const traverseSession, 141 const DicNode *const dicNode) const { 142 return ScoringParams::COST_NEW_WORD * traverseSession->getMultiWordCostMultiplier(); 143 } 144 145 float getNewWordBigramCost(const DicTraverseSession *const traverseSession, 146 const DicNode *const dicNode, 147 MultiBigramMap *const multiBigramMap) const { 148 return DicNodeUtils::getBigramNodeImprobability(traverseSession->getOffsetDict(), 149 dicNode, multiBigramMap) * ScoringParams::DISTANCE_WEIGHT_LANGUAGE; 150 } 151 152 float getCompletionCost(const DicTraverseSession *const traverseSession, 153 const DicNode *const dicNode) const { 154 // The auto completion starts when the input index is same as the input size 155 const bool firstCompletion = dicNode->getInputIndex(0) 156 == traverseSession->getInputSize(); 157 // TODO: Change the cost for the first completion for the gesture? 158 const float cost = firstCompletion ? ScoringParams::COST_FIRST_LOOKAHEAD 159 : ScoringParams::COST_LOOKAHEAD; 160 return cost; 161 } 162 163 float getTerminalLanguageCost(const DicTraverseSession *const traverseSession, 164 const DicNode *const dicNode, const float dicNodeLanguageImprobability) const { 165 const float languageImprobability = (dicNode->isExactMatch()) ? 166 0.0f : dicNodeLanguageImprobability; 167 return languageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE; 168 } 169 170 AK_FORCE_INLINE bool needsToNormalizeCompoundDistance() const { 171 return false; 172 } 173 174 AK_FORCE_INLINE float getAdditionalProximityCost() const { 175 return ScoringParams::ADDITIONAL_PROXIMITY_COST; 176 } 177 178 AK_FORCE_INLINE float getSubstitutionCost() const { 179 return ScoringParams::SUBSTITUTION_COST; 180 } 181 182 AK_FORCE_INLINE float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession, 183 const DicNode *const dicNode) const { 184 const float cost = ScoringParams::SPACE_SUBSTITUTION_COST + ScoringParams::COST_NEW_WORD; 185 return cost * traverseSession->getMultiWordCostMultiplier(); 186 } 187 188 ErrorType getErrorType(const CorrectionType correctionType, 189 const DicTraverseSession *const traverseSession, 190 const DicNode *const parentDicNode, const DicNode *const dicNode) const; 191 192 private: 193 DISALLOW_COPY_AND_ASSIGN(TypingWeighting); 194 static const TypingWeighting sInstance; 195 196 TypingWeighting() {} 197 ~TypingWeighting() {} 198 }; 199 } // namespace latinime 200 #endif // LATINIME_TYPING_WEIGHTING_H 201