1 /* 2 * Copyright (C) 2012 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef LATINIME_TYPING_WEIGHTING_H 18 #define LATINIME_TYPING_WEIGHTING_H 19 20 #include "defines.h" 21 #include "suggest/core/dicnode/dic_node_utils.h" 22 #include "suggest/core/dictionary/error_type_utils.h" 23 #include "suggest/core/layout/touch_position_correction_utils.h" 24 #include "suggest/core/policy/weighting.h" 25 #include "suggest/core/session/dic_traverse_session.h" 26 #include "suggest/policyimpl/typing/scoring_params.h" 27 #include "utils/char_utils.h" 28 29 namespace latinime { 30 31 class DicNode; 32 struct DicNode_InputStateG; 33 class MultiBigramMap; 34 35 class TypingWeighting : public Weighting { 36 public: 37 static const TypingWeighting *getInstance() { return &sInstance; } 38 39 protected: 40 float getTerminalSpatialCost(const DicTraverseSession *const traverseSession, 41 const DicNode *const dicNode) const { 42 float cost = 0.0f; 43 if (dicNode->hasMultipleWords()) { 44 cost += ScoringParams::HAS_MULTI_WORD_TERMINAL_COST; 45 } 46 if (dicNode->getProximityCorrectionCount() > 0) { 47 cost += ScoringParams::HAS_PROXIMITY_TERMINAL_COST; 48 } 49 if (dicNode->getEditCorrectionCount() > 0) { 50 cost += ScoringParams::HAS_EDIT_CORRECTION_TERMINAL_COST; 51 } 52 return cost; 53 } 54 55 float getOmissionCost(const DicNode *const parentDicNode, const DicNode *const dicNode) const { 56 const bool isZeroCostOmission = parentDicNode->isZeroCostOmission(); 57 const bool isIntentionalOmission = parentDicNode->canBeIntentionalOmission(); 58 const bool sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode); 59 // If the traversal omitted the first letter then the dicNode should now be on the second. 60 const bool isFirstLetterOmission = dicNode->getNodeCodePointCount() == 2; 61 float cost = 0.0f; 62 if (isZeroCostOmission) { 63 cost = 0.0f; 64 } else if (isIntentionalOmission) { 65 cost = ScoringParams::INTENTIONAL_OMISSION_COST; 66 } else if (isFirstLetterOmission) { 67 cost = ScoringParams::OMISSION_COST_FIRST_CHAR; 68 } else { 69 cost = sameCodePoint ? ScoringParams::OMISSION_COST_SAME_CHAR 70 : ScoringParams::OMISSION_COST; 71 } 72 return cost; 73 } 74 75 float getMatchedCost(const DicTraverseSession *const traverseSession, 76 const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const { 77 const int pointIndex = dicNode->getInputIndex(0); 78 const float normalizedSquaredLength = traverseSession->getProximityInfoState(0) 79 ->getPointToKeyLength(pointIndex, 80 CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint())); 81 const float normalizedDistance = TouchPositionCorrectionUtils::getSweetSpotFactor( 82 traverseSession->isTouchPositionCorrectionEnabled(), normalizedSquaredLength); 83 const float weightedDistance = ScoringParams::DISTANCE_WEIGHT_LENGTH * normalizedDistance; 84 85 const bool isFirstChar = pointIndex == 0; 86 const bool isProximity = isProximityDicNode(traverseSession, dicNode); 87 float cost = isProximity ? (isFirstChar ? ScoringParams::FIRST_CHAR_PROXIMITY_COST 88 : ScoringParams::PROXIMITY_COST) : 0.0f; 89 if (isProximity && dicNode->getProximityCorrectionCount() == 0) { 90 cost += ScoringParams::FIRST_PROXIMITY_COST; 91 } 92 if (dicNode->getNodeCodePointCount() == 2) { 93 // At the second character of the current word, we check if the first char is uppercase 94 // and the word is a second or later word of a multiple word suggestion. We demote it 95 // if so. 96 const bool isSecondOrLaterWordFirstCharUppercase = 97 dicNode->hasMultipleWords() && dicNode->isFirstCharUppercase(); 98 if (isSecondOrLaterWordFirstCharUppercase) { 99 cost += ScoringParams::COST_SECOND_OR_LATER_WORD_FIRST_CHAR_UPPERCASE; 100 } 101 } 102 return weightedDistance + cost; 103 } 104 105 bool isProximityDicNode(const DicTraverseSession *const traverseSession, 106 const DicNode *const dicNode) const { 107 const int pointIndex = dicNode->getInputIndex(0); 108 const int primaryCodePoint = CharUtils::toBaseLowerCase( 109 traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(pointIndex)); 110 const int dicNodeChar = CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint()); 111 return primaryCodePoint != dicNodeChar; 112 } 113 114 float getTranspositionCost(const DicTraverseSession *const traverseSession, 115 const DicNode *const parentDicNode, const DicNode *const dicNode) const { 116 const int16_t parentPointIndex = parentDicNode->getInputIndex(0); 117 const int prevCodePoint = parentDicNode->getNodeCodePoint(); 118 const float distance1 = traverseSession->getProximityInfoState(0)->getPointToKeyLength( 119 parentPointIndex + 1, CharUtils::toBaseLowerCase(prevCodePoint)); 120 const int codePoint = dicNode->getNodeCodePoint(); 121 const float distance2 = traverseSession->getProximityInfoState(0)->getPointToKeyLength( 122 parentPointIndex, CharUtils::toBaseLowerCase(codePoint)); 123 const float distance = distance1 + distance2; 124 const float weightedLengthDistance = 125 distance * ScoringParams::DISTANCE_WEIGHT_LENGTH; 126 return ScoringParams::TRANSPOSITION_COST + weightedLengthDistance; 127 } 128 129 float getInsertionCost(const DicTraverseSession *const traverseSession, 130 const DicNode *const parentDicNode, const DicNode *const dicNode) const { 131 const int16_t insertedPointIndex = parentDicNode->getInputIndex(0); 132 const int prevCodePoint = traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt( 133 insertedPointIndex); 134 const int currentCodePoint = dicNode->getNodeCodePoint(); 135 const bool sameCodePoint = prevCodePoint == currentCodePoint; 136 const bool existsAdjacentProximityChars = traverseSession->getProximityInfoState(0) 137 ->existsAdjacentProximityChars(insertedPointIndex); 138 const float dist = traverseSession->getProximityInfoState(0)->getPointToKeyLength( 139 insertedPointIndex + 1, CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint())); 140 const float weightedDistance = dist * ScoringParams::DISTANCE_WEIGHT_LENGTH; 141 const bool singleChar = dicNode->getNodeCodePointCount() == 1; 142 float cost = (singleChar ? ScoringParams::INSERTION_COST_FIRST_CHAR : 0.0f); 143 if (sameCodePoint) { 144 cost += ScoringParams::INSERTION_COST_SAME_CHAR; 145 } else if (existsAdjacentProximityChars) { 146 cost += ScoringParams::INSERTION_COST_PROXIMITY_CHAR; 147 } else { 148 cost += ScoringParams::INSERTION_COST; 149 } 150 return cost + weightedDistance; 151 } 152 153 float getNewWordSpatialCost(const DicTraverseSession *const traverseSession, 154 const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const { 155 return ScoringParams::COST_NEW_WORD * traverseSession->getMultiWordCostMultiplier(); 156 } 157 158 float getNewWordBigramLanguageCost(const DicTraverseSession *const traverseSession, 159 const DicNode *const dicNode, 160 MultiBigramMap *const multiBigramMap) const { 161 return DicNodeUtils::getBigramNodeImprobability( 162 traverseSession->getDictionaryStructurePolicy(), 163 dicNode, multiBigramMap) * ScoringParams::DISTANCE_WEIGHT_LANGUAGE; 164 } 165 166 float getCompletionCost(const DicTraverseSession *const traverseSession, 167 const DicNode *const dicNode) const { 168 // The auto completion starts when the input index is same as the input size 169 const bool firstCompletion = dicNode->getInputIndex(0) 170 == traverseSession->getInputSize(); 171 // TODO: Change the cost for the first completion for the gesture? 172 const float cost = firstCompletion ? ScoringParams::COST_FIRST_COMPLETION 173 : ScoringParams::COST_COMPLETION; 174 return cost; 175 } 176 177 float getTerminalLanguageCost(const DicTraverseSession *const traverseSession, 178 const DicNode *const dicNode, const float dicNodeLanguageImprobability) const { 179 return dicNodeLanguageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE; 180 } 181 182 float getTerminalInsertionCost(const DicTraverseSession *const traverseSession, 183 const DicNode *const dicNode) const { 184 const int inputIndex = dicNode->getInputIndex(0); 185 const int inputSize = traverseSession->getInputSize(); 186 ASSERT(inputIndex < inputSize); 187 // TODO: Implement more efficient logic 188 return ScoringParams::TERMINAL_INSERTION_COST * (inputSize - inputIndex); 189 } 190 191 AK_FORCE_INLINE bool needsToNormalizeCompoundDistance() const { 192 return false; 193 } 194 195 AK_FORCE_INLINE float getAdditionalProximityCost() const { 196 return ScoringParams::ADDITIONAL_PROXIMITY_COST; 197 } 198 199 AK_FORCE_INLINE float getSubstitutionCost() const { 200 return ScoringParams::SUBSTITUTION_COST; 201 } 202 203 AK_FORCE_INLINE float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession, 204 const DicNode *const dicNode) const { 205 const float cost = ScoringParams::SPACE_SUBSTITUTION_COST + ScoringParams::COST_NEW_WORD; 206 return cost * traverseSession->getMultiWordCostMultiplier(); 207 } 208 209 ErrorTypeUtils::ErrorType getErrorType(const CorrectionType correctionType, 210 const DicTraverseSession *const traverseSession, 211 const DicNode *const parentDicNode, const DicNode *const dicNode) const; 212 213 private: 214 DISALLOW_COPY_AND_ASSIGN(TypingWeighting); 215 static const TypingWeighting sInstance; 216 217 TypingWeighting() {} 218 ~TypingWeighting() {} 219 }; 220 } // namespace latinime 221 #endif // LATINIME_TYPING_WEIGHTING_H 222