1 /* 2 * Copyright (C) 2012 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef LATINIME_TYPING_WEIGHTING_H 18 #define LATINIME_TYPING_WEIGHTING_H 19 20 #include "defines.h" 21 #include "suggest/core/dicnode/dic_node_utils.h" 22 #include "suggest/core/layout/touch_position_correction_utils.h" 23 #include "suggest/core/policy/weighting.h" 24 #include "suggest/core/session/dic_traverse_session.h" 25 #include "suggest/policyimpl/typing/scoring_params.h" 26 #include "utils/char_utils.h" 27 28 namespace latinime { 29 30 class DicNode; 31 struct DicNode_InputStateG; 32 class MultiBigramMap; 33 34 class TypingWeighting : public Weighting { 35 public: 36 static const TypingWeighting *getInstance() { return &sInstance; } 37 38 protected: 39 float getTerminalSpatialCost(const DicTraverseSession *const traverseSession, 40 const DicNode *const dicNode) const { 41 float cost = 0.0f; 42 if (dicNode->hasMultipleWords()) { 43 cost += ScoringParams::HAS_MULTI_WORD_TERMINAL_COST; 44 } 45 if (dicNode->getProximityCorrectionCount() > 0) { 46 cost += ScoringParams::HAS_PROXIMITY_TERMINAL_COST; 47 } 48 if (dicNode->getEditCorrectionCount() > 0) { 49 cost += ScoringParams::HAS_EDIT_CORRECTION_TERMINAL_COST; 50 } 51 return cost; 52 } 53 54 float getOmissionCost(const DicNode *const parentDicNode, const DicNode *const dicNode) const { 55 const bool isZeroCostOmission = parentDicNode->isZeroCostOmission(); 56 const bool sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode); 57 // If the traversal omitted the first letter then the dicNode should now be on the second. 58 const bool isFirstLetterOmission = dicNode->getNodeCodePointCount() == 2; 59 float cost = 0.0f; 60 if (isZeroCostOmission) { 61 cost = 0.0f; 62 } else if (isFirstLetterOmission) { 63 cost = ScoringParams::OMISSION_COST_FIRST_CHAR; 64 } else { 65 cost = sameCodePoint ? ScoringParams::OMISSION_COST_SAME_CHAR 66 : ScoringParams::OMISSION_COST; 67 } 68 return cost; 69 } 70 71 float getMatchedCost(const DicTraverseSession *const traverseSession, 72 const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const { 73 const int pointIndex = dicNode->getInputIndex(0); 74 // Note: min() required since length can be MAX_POINT_TO_KEY_LENGTH for characters not on 75 // the keyboard (like accented letters) 76 const float normalizedSquaredLength = traverseSession->getProximityInfoState(0) 77 ->getPointToKeyLength(pointIndex, 78 CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint())); 79 const float normalizedDistance = TouchPositionCorrectionUtils::getSweetSpotFactor( 80 traverseSession->isTouchPositionCorrectionEnabled(), normalizedSquaredLength); 81 const float weightedDistance = ScoringParams::DISTANCE_WEIGHT_LENGTH * normalizedDistance; 82 83 const bool isFirstChar = pointIndex == 0; 84 const bool isProximity = isProximityDicNode(traverseSession, dicNode); 85 float cost = isProximity ? (isFirstChar ? ScoringParams::FIRST_CHAR_PROXIMITY_COST 86 : ScoringParams::PROXIMITY_COST) : 0.0f; 87 if (isProximity && dicNode->getProximityCorrectionCount() == 0) { 88 cost += ScoringParams::FIRST_PROXIMITY_COST; 89 } 90 if (dicNode->getNodeCodePointCount() == 2) { 91 // At the second character of the current word, we check if the first char is uppercase 92 // and the word is a second or later word of a multiple word suggestion. We demote it 93 // if so. 94 const bool isSecondOrLaterWordFirstCharUppercase = 95 dicNode->hasMultipleWords() && dicNode->isFirstCharUppercase(); 96 if (isSecondOrLaterWordFirstCharUppercase) { 97 cost += ScoringParams::COST_SECOND_OR_LATER_WORD_FIRST_CHAR_UPPERCASE; 98 } 99 } 100 return weightedDistance + cost; 101 } 102 103 bool isProximityDicNode(const DicTraverseSession *const traverseSession, 104 const DicNode *const dicNode) const { 105 const int pointIndex = dicNode->getInputIndex(0); 106 const int primaryCodePoint = CharUtils::toBaseLowerCase( 107 traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(pointIndex)); 108 const int dicNodeChar = CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint()); 109 return primaryCodePoint != dicNodeChar; 110 } 111 112 float getTranspositionCost(const DicTraverseSession *const traverseSession, 113 const DicNode *const parentDicNode, const DicNode *const dicNode) const { 114 const int16_t parentPointIndex = parentDicNode->getInputIndex(0); 115 const int prevCodePoint = parentDicNode->getNodeCodePoint(); 116 const float distance1 = traverseSession->getProximityInfoState(0)->getPointToKeyLength( 117 parentPointIndex + 1, CharUtils::toBaseLowerCase(prevCodePoint)); 118 const int codePoint = dicNode->getNodeCodePoint(); 119 const float distance2 = traverseSession->getProximityInfoState(0)->getPointToKeyLength( 120 parentPointIndex, CharUtils::toBaseLowerCase(codePoint)); 121 const float distance = distance1 + distance2; 122 const float weightedLengthDistance = 123 distance * ScoringParams::DISTANCE_WEIGHT_LENGTH; 124 return ScoringParams::TRANSPOSITION_COST + weightedLengthDistance; 125 } 126 127 float getInsertionCost(const DicTraverseSession *const traverseSession, 128 const DicNode *const parentDicNode, const DicNode *const dicNode) const { 129 const int16_t insertedPointIndex = parentDicNode->getInputIndex(0); 130 const int prevCodePoint = traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt( 131 insertedPointIndex); 132 const int currentCodePoint = dicNode->getNodeCodePoint(); 133 const bool sameCodePoint = prevCodePoint == currentCodePoint; 134 const bool existsAdjacentProximityChars = traverseSession->getProximityInfoState(0) 135 ->existsAdjacentProximityChars(insertedPointIndex); 136 const float dist = traverseSession->getProximityInfoState(0)->getPointToKeyLength( 137 insertedPointIndex + 1, CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint())); 138 const float weightedDistance = dist * ScoringParams::DISTANCE_WEIGHT_LENGTH; 139 const bool singleChar = dicNode->getNodeCodePointCount() == 1; 140 float cost = (singleChar ? ScoringParams::INSERTION_COST_FIRST_CHAR : 0.0f); 141 if (sameCodePoint) { 142 cost += ScoringParams::INSERTION_COST_SAME_CHAR; 143 } else if (existsAdjacentProximityChars) { 144 cost += ScoringParams::INSERTION_COST_PROXIMITY_CHAR; 145 } else { 146 cost += ScoringParams::INSERTION_COST; 147 } 148 return cost + weightedDistance; 149 } 150 151 float getNewWordSpatialCost(const DicTraverseSession *const traverseSession, 152 const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const { 153 return ScoringParams::COST_NEW_WORD * traverseSession->getMultiWordCostMultiplier(); 154 } 155 156 float getNewWordBigramLanguageCost(const DicTraverseSession *const traverseSession, 157 const DicNode *const dicNode, 158 MultiBigramMap *const multiBigramMap) const { 159 return DicNodeUtils::getBigramNodeImprobability( 160 traverseSession->getDictionaryStructurePolicy(), 161 dicNode, multiBigramMap) * ScoringParams::DISTANCE_WEIGHT_LANGUAGE; 162 } 163 164 float getCompletionCost(const DicTraverseSession *const traverseSession, 165 const DicNode *const dicNode) const { 166 // The auto completion starts when the input index is same as the input size 167 const bool firstCompletion = dicNode->getInputIndex(0) 168 == traverseSession->getInputSize(); 169 // TODO: Change the cost for the first completion for the gesture? 170 const float cost = firstCompletion ? ScoringParams::COST_FIRST_LOOKAHEAD 171 : ScoringParams::COST_LOOKAHEAD; 172 return cost; 173 } 174 175 float getTerminalLanguageCost(const DicTraverseSession *const traverseSession, 176 const DicNode *const dicNode, const float dicNodeLanguageImprobability) const { 177 return dicNodeLanguageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE; 178 } 179 180 float getTerminalInsertionCost(const DicTraverseSession *const traverseSession, 181 const DicNode *const dicNode) const { 182 const int inputIndex = dicNode->getInputIndex(0); 183 const int inputSize = traverseSession->getInputSize(); 184 ASSERT(inputIndex < inputSize); 185 // TODO: Implement more efficient logic 186 return ScoringParams::TERMINAL_INSERTION_COST * (inputSize - inputIndex); 187 } 188 189 AK_FORCE_INLINE bool needsToNormalizeCompoundDistance() const { 190 return false; 191 } 192 193 AK_FORCE_INLINE float getAdditionalProximityCost() const { 194 return ScoringParams::ADDITIONAL_PROXIMITY_COST; 195 } 196 197 AK_FORCE_INLINE float getSubstitutionCost() const { 198 return ScoringParams::SUBSTITUTION_COST; 199 } 200 201 AK_FORCE_INLINE float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession, 202 const DicNode *const dicNode) const { 203 const float cost = ScoringParams::SPACE_SUBSTITUTION_COST + ScoringParams::COST_NEW_WORD; 204 return cost * traverseSession->getMultiWordCostMultiplier(); 205 } 206 207 ErrorType getErrorType(const CorrectionType correctionType, 208 const DicTraverseSession *const traverseSession, 209 const DicNode *const parentDicNode, const DicNode *const dicNode) const; 210 211 private: 212 DISALLOW_COPY_AND_ASSIGN(TypingWeighting); 213 static const TypingWeighting sInstance; 214 215 TypingWeighting() {} 216 ~TypingWeighting() {} 217 }; 218 } // namespace latinime 219 #endif // LATINIME_TYPING_WEIGHTING_H 220