1 /* 2 * Copyright (C) 2012 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef LATINIME_TYPING_WEIGHTING_H 18 #define LATINIME_TYPING_WEIGHTING_H 19 20 #include "defines.h" 21 #include "suggest/core/dicnode/dic_node_utils.h" 22 #include "suggest/core/dictionary/error_type_utils.h" 23 #include "suggest/core/layout/touch_position_correction_utils.h" 24 #include "suggest/core/policy/weighting.h" 25 #include "suggest/core/session/dic_traverse_session.h" 26 #include "suggest/policyimpl/typing/scoring_params.h" 27 #include "utils/char_utils.h" 28 29 namespace latinime { 30 31 class DicNode; 32 struct DicNode_InputStateG; 33 class MultiBigramMap; 34 35 class TypingWeighting : public Weighting { 36 public: 37 static const TypingWeighting *getInstance() { return &sInstance; } 38 39 protected: 40 float getTerminalSpatialCost(const DicTraverseSession *const traverseSession, 41 const DicNode *const dicNode) const { 42 float cost = 0.0f; 43 if (dicNode->hasMultipleWords()) { 44 cost += ScoringParams::HAS_MULTI_WORD_TERMINAL_COST; 45 } 46 if (dicNode->getProximityCorrectionCount() > 0) { 47 cost += ScoringParams::HAS_PROXIMITY_TERMINAL_COST; 48 } 49 if (dicNode->getEditCorrectionCount() > 0) { 50 cost += ScoringParams::HAS_EDIT_CORRECTION_TERMINAL_COST; 51 } 52 return cost; 53 } 54 55 float getOmissionCost(const DicNode *const parentDicNode, const DicNode *const dicNode) const { 56 const bool isZeroCostOmission = parentDicNode->isZeroCostOmission(); 57 const bool isIntentionalOmission = parentDicNode->canBeIntentionalOmission(); 58 const bool sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode); 59 // If the traversal omitted the first letter then the dicNode should now be on the second. 60 const bool isFirstLetterOmission = dicNode->getNodeCodePointCount() == 2; 61 float cost = 0.0f; 62 if (isZeroCostOmission) { 63 cost = 0.0f; 64 } else if (isIntentionalOmission) { 65 cost = ScoringParams::INTENTIONAL_OMISSION_COST; 66 } else if (isFirstLetterOmission) { 67 cost = ScoringParams::OMISSION_COST_FIRST_CHAR; 68 } else { 69 cost = sameCodePoint ? ScoringParams::OMISSION_COST_SAME_CHAR 70 : ScoringParams::OMISSION_COST; 71 } 72 return cost; 73 } 74 75 float getMatchedCost(const DicTraverseSession *const traverseSession, 76 const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const { 77 const int pointIndex = dicNode->getInputIndex(0); 78 const float normalizedSquaredLength = traverseSession->getProximityInfoState(0) 79 ->getPointToKeyLength(pointIndex, 80 CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint())); 81 const float normalizedDistance = TouchPositionCorrectionUtils::getSweetSpotFactor( 82 traverseSession->isTouchPositionCorrectionEnabled(), normalizedSquaredLength); 83 const float weightedDistance = ScoringParams::DISTANCE_WEIGHT_LENGTH * normalizedDistance; 84 85 const bool isFirstChar = pointIndex == 0; 86 const bool isProximity = isProximityDicNode(traverseSession, dicNode); 87 float cost = isProximity ? (isFirstChar ? ScoringParams::FIRST_CHAR_PROXIMITY_COST 88 : ScoringParams::PROXIMITY_COST) : 0.0f; 89 if (isProximity && dicNode->getProximityCorrectionCount() == 0) { 90 cost += ScoringParams::FIRST_PROXIMITY_COST; 91 } 92 if (dicNode->getNodeCodePointCount() == 2) { 93 // At the second character of the current word, we check if the first char is uppercase 94 // and the word is a second or later word of a multiple word suggestion. We demote it 95 // if so. 96 const bool isSecondOrLaterWordFirstCharUppercase = 97 dicNode->hasMultipleWords() && dicNode->isFirstCharUppercase(); 98 if (isSecondOrLaterWordFirstCharUppercase) { 99 cost += ScoringParams::COST_SECOND_OR_LATER_WORD_FIRST_CHAR_UPPERCASE; 100 } 101 } 102 return weightedDistance + cost; 103 } 104 105 bool isProximityDicNode(const DicTraverseSession *const traverseSession, 106 const DicNode *const dicNode) const { 107 const int pointIndex = dicNode->getInputIndex(0); 108 const int primaryCodePoint = CharUtils::toBaseLowerCase( 109 traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(pointIndex)); 110 const int dicNodeChar = CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint()); 111 return primaryCodePoint != dicNodeChar; 112 } 113 114 float getTranspositionCost(const DicTraverseSession *const traverseSession, 115 const DicNode *const parentDicNode, const DicNode *const dicNode) const { 116 const int16_t parentPointIndex = parentDicNode->getInputIndex(0); 117 const int prevCodePoint = parentDicNode->getNodeCodePoint(); 118 const float distance1 = traverseSession->getProximityInfoState(0)->getPointToKeyLength( 119 parentPointIndex + 1, CharUtils::toBaseLowerCase(prevCodePoint)); 120 const int codePoint = dicNode->getNodeCodePoint(); 121 const float distance2 = traverseSession->getProximityInfoState(0)->getPointToKeyLength( 122 parentPointIndex, CharUtils::toBaseLowerCase(codePoint)); 123 const float distance = distance1 + distance2; 124 const float weightedLengthDistance = 125 distance * ScoringParams::DISTANCE_WEIGHT_LENGTH; 126 return ScoringParams::TRANSPOSITION_COST + weightedLengthDistance; 127 } 128 129 float getInsertionCost(const DicTraverseSession *const traverseSession, 130 const DicNode *const parentDicNode, const DicNode *const dicNode) const { 131 const int16_t insertedPointIndex = parentDicNode->getInputIndex(0); 132 const int prevCodePoint = traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt( 133 insertedPointIndex); 134 const int currentCodePoint = dicNode->getNodeCodePoint(); 135 const bool sameCodePoint = prevCodePoint == currentCodePoint; 136 const bool existsAdjacentProximityChars = traverseSession->getProximityInfoState(0) 137 ->existsAdjacentProximityChars(insertedPointIndex); 138 const float dist = traverseSession->getProximityInfoState(0)->getPointToKeyLength( 139 insertedPointIndex + 1, CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint())); 140 const float weightedDistance = dist * ScoringParams::DISTANCE_WEIGHT_LENGTH; 141 const bool singleChar = dicNode->getNodeCodePointCount() == 1; 142 float cost = (singleChar ? ScoringParams::INSERTION_COST_FIRST_CHAR : 0.0f); 143 if (sameCodePoint) { 144 cost += ScoringParams::INSERTION_COST_SAME_CHAR; 145 } else if (existsAdjacentProximityChars) { 146 cost += ScoringParams::INSERTION_COST_PROXIMITY_CHAR; 147 } else { 148 cost += ScoringParams::INSERTION_COST; 149 } 150 return cost + weightedDistance; 151 } 152 153 float getSpaceOmissionCost(const DicTraverseSession *const traverseSession, 154 const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const { 155 const float cost = ScoringParams::SPACE_OMISSION_COST; 156 return cost * traverseSession->getMultiWordCostMultiplier(); 157 } 158 159 float getNewWordBigramLanguageCost(const DicTraverseSession *const traverseSession, 160 const DicNode *const dicNode, 161 MultiBigramMap *const multiBigramMap) const { 162 return DicNodeUtils::getBigramNodeImprobability( 163 traverseSession->getDictionaryStructurePolicy(), 164 dicNode, multiBigramMap) * ScoringParams::DISTANCE_WEIGHT_LANGUAGE; 165 } 166 167 float getCompletionCost(const DicTraverseSession *const traverseSession, 168 const DicNode *const dicNode) const { 169 // The auto completion starts when the input index is same as the input size 170 const bool firstCompletion = dicNode->getInputIndex(0) 171 == traverseSession->getInputSize(); 172 // TODO: Change the cost for the first completion for the gesture? 173 const float cost = firstCompletion ? ScoringParams::COST_FIRST_COMPLETION 174 : ScoringParams::COST_COMPLETION; 175 return cost; 176 } 177 178 float getTerminalLanguageCost(const DicTraverseSession *const traverseSession, 179 const DicNode *const dicNode, const float dicNodeLanguageImprobability) const { 180 return dicNodeLanguageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE; 181 } 182 183 float getTerminalInsertionCost(const DicTraverseSession *const traverseSession, 184 const DicNode *const dicNode) const { 185 const int inputIndex = dicNode->getInputIndex(0); 186 const int inputSize = traverseSession->getInputSize(); 187 ASSERT(inputIndex < inputSize); 188 // TODO: Implement more efficient logic 189 return ScoringParams::TERMINAL_INSERTION_COST * (inputSize - inputIndex); 190 } 191 192 AK_FORCE_INLINE bool needsToNormalizeCompoundDistance() const { 193 return false; 194 } 195 196 AK_FORCE_INLINE float getAdditionalProximityCost() const { 197 return ScoringParams::ADDITIONAL_PROXIMITY_COST; 198 } 199 200 AK_FORCE_INLINE float getSubstitutionCost() const { 201 return ScoringParams::SUBSTITUTION_COST; 202 } 203 204 AK_FORCE_INLINE float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession, 205 const DicNode *const dicNode) const { 206 const int inputIndex = dicNode->getInputIndex(0); 207 const float distanceToSpaceKey = traverseSession->getProximityInfoState(0) 208 ->getPointToKeyLength(inputIndex, KEYCODE_SPACE); 209 const float cost = ScoringParams::SPACE_SUBSTITUTION_COST * distanceToSpaceKey; 210 return cost * traverseSession->getMultiWordCostMultiplier(); 211 } 212 213 ErrorTypeUtils::ErrorType getErrorType(const CorrectionType correctionType, 214 const DicTraverseSession *const traverseSession, 215 const DicNode *const parentDicNode, const DicNode *const dicNode) const; 216 217 private: 218 DISALLOW_COPY_AND_ASSIGN(TypingWeighting); 219 static const TypingWeighting sInstance; 220 221 TypingWeighting() {} 222 ~TypingWeighting() {} 223 }; 224 } // namespace latinime 225 #endif // LATINIME_TYPING_WEIGHTING_H 226