Home | History | Annotate | Download | only in typing
      1 /*
      2  * Copyright (C) 2012 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef LATINIME_TYPING_WEIGHTING_H
     18 #define LATINIME_TYPING_WEIGHTING_H
     19 
     20 #include "defines.h"
     21 #include "suggest/core/dicnode/dic_node_utils.h"
     22 #include "suggest/core/layout/touch_position_correction_utils.h"
     23 #include "suggest/core/policy/weighting.h"
     24 #include "suggest/core/session/dic_traverse_session.h"
     25 #include "suggest/policyimpl/typing/scoring_params.h"
     26 #include "utils/char_utils.h"
     27 
     28 namespace latinime {
     29 
     30 class DicNode;
     31 struct DicNode_InputStateG;
     32 class MultiBigramMap;
     33 
     34 class TypingWeighting : public Weighting {
     35  public:
     36     static const TypingWeighting *getInstance() { return &sInstance; }
     37 
     38  protected:
     39     float getTerminalSpatialCost(const DicTraverseSession *const traverseSession,
     40             const DicNode *const dicNode) const {
     41         float cost = 0.0f;
     42         if (dicNode->hasMultipleWords()) {
     43             cost += ScoringParams::HAS_MULTI_WORD_TERMINAL_COST;
     44         }
     45         if (dicNode->getProximityCorrectionCount() > 0) {
     46             cost += ScoringParams::HAS_PROXIMITY_TERMINAL_COST;
     47         }
     48         if (dicNode->getEditCorrectionCount() > 0) {
     49             cost += ScoringParams::HAS_EDIT_CORRECTION_TERMINAL_COST;
     50         }
     51         return cost;
     52     }
     53 
     54     float getOmissionCost(const DicNode *const parentDicNode, const DicNode *const dicNode) const {
     55         const bool isZeroCostOmission = parentDicNode->isZeroCostOmission();
     56         const bool sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode);
     57         // If the traversal omitted the first letter then the dicNode should now be on the second.
     58         const bool isFirstLetterOmission = dicNode->getNodeCodePointCount() == 2;
     59         float cost = 0.0f;
     60         if (isZeroCostOmission) {
     61             cost = 0.0f;
     62         } else if (isFirstLetterOmission) {
     63             cost = ScoringParams::OMISSION_COST_FIRST_CHAR;
     64         } else {
     65             cost = sameCodePoint ? ScoringParams::OMISSION_COST_SAME_CHAR
     66                     : ScoringParams::OMISSION_COST;
     67         }
     68         return cost;
     69     }
     70 
     71     float getMatchedCost(const DicTraverseSession *const traverseSession,
     72             const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const {
     73         const int pointIndex = dicNode->getInputIndex(0);
     74         // Note: min() required since length can be MAX_POINT_TO_KEY_LENGTH for characters not on
     75         // the keyboard (like accented letters)
     76         const float normalizedSquaredLength = traverseSession->getProximityInfoState(0)
     77                 ->getPointToKeyLength(pointIndex,
     78                         CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint()));
     79         const float normalizedDistance = TouchPositionCorrectionUtils::getSweetSpotFactor(
     80                 traverseSession->isTouchPositionCorrectionEnabled(), normalizedSquaredLength);
     81         const float weightedDistance = ScoringParams::DISTANCE_WEIGHT_LENGTH * normalizedDistance;
     82 
     83         const bool isFirstChar = pointIndex == 0;
     84         const bool isProximity = isProximityDicNode(traverseSession, dicNode);
     85         float cost = isProximity ? (isFirstChar ? ScoringParams::FIRST_CHAR_PROXIMITY_COST
     86                 : ScoringParams::PROXIMITY_COST) : 0.0f;
     87         if (isProximity && dicNode->getProximityCorrectionCount() == 0) {
     88             cost += ScoringParams::FIRST_PROXIMITY_COST;
     89         }
     90         if (dicNode->getNodeCodePointCount() == 2) {
     91             // At the second character of the current word, we check if the first char is uppercase
     92             // and the word is a second or later word of a multiple word suggestion. We demote it
     93             // if so.
     94             const bool isSecondOrLaterWordFirstCharUppercase =
     95                     dicNode->hasMultipleWords() && dicNode->isFirstCharUppercase();
     96             if (isSecondOrLaterWordFirstCharUppercase) {
     97                 cost += ScoringParams::COST_SECOND_OR_LATER_WORD_FIRST_CHAR_UPPERCASE;
     98             }
     99         }
    100         return weightedDistance + cost;
    101     }
    102 
    103     bool isProximityDicNode(const DicTraverseSession *const traverseSession,
    104             const DicNode *const dicNode) const {
    105         const int pointIndex = dicNode->getInputIndex(0);
    106         const int primaryCodePoint = CharUtils::toBaseLowerCase(
    107                 traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(pointIndex));
    108         const int dicNodeChar = CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint());
    109         return primaryCodePoint != dicNodeChar;
    110     }
    111 
    112     float getTranspositionCost(const DicTraverseSession *const traverseSession,
    113             const DicNode *const parentDicNode, const DicNode *const dicNode) const {
    114         const int16_t parentPointIndex = parentDicNode->getInputIndex(0);
    115         const int prevCodePoint = parentDicNode->getNodeCodePoint();
    116         const float distance1 = traverseSession->getProximityInfoState(0)->getPointToKeyLength(
    117                 parentPointIndex + 1, CharUtils::toBaseLowerCase(prevCodePoint));
    118         const int codePoint = dicNode->getNodeCodePoint();
    119         const float distance2 = traverseSession->getProximityInfoState(0)->getPointToKeyLength(
    120                 parentPointIndex, CharUtils::toBaseLowerCase(codePoint));
    121         const float distance = distance1 + distance2;
    122         const float weightedLengthDistance =
    123                 distance * ScoringParams::DISTANCE_WEIGHT_LENGTH;
    124         return ScoringParams::TRANSPOSITION_COST + weightedLengthDistance;
    125     }
    126 
    127     float getInsertionCost(const DicTraverseSession *const traverseSession,
    128             const DicNode *const parentDicNode, const DicNode *const dicNode) const {
    129         const int16_t insertedPointIndex = parentDicNode->getInputIndex(0);
    130         const int prevCodePoint = traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(
    131                 insertedPointIndex);
    132         const int currentCodePoint = dicNode->getNodeCodePoint();
    133         const bool sameCodePoint = prevCodePoint == currentCodePoint;
    134         const bool existsAdjacentProximityChars = traverseSession->getProximityInfoState(0)
    135                 ->existsAdjacentProximityChars(insertedPointIndex);
    136         const float dist = traverseSession->getProximityInfoState(0)->getPointToKeyLength(
    137                 insertedPointIndex + 1, CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint()));
    138         const float weightedDistance = dist * ScoringParams::DISTANCE_WEIGHT_LENGTH;
    139         const bool singleChar = dicNode->getNodeCodePointCount() == 1;
    140         float cost = (singleChar ? ScoringParams::INSERTION_COST_FIRST_CHAR : 0.0f);
    141         if (sameCodePoint) {
    142             cost += ScoringParams::INSERTION_COST_SAME_CHAR;
    143         } else if (existsAdjacentProximityChars) {
    144             cost += ScoringParams::INSERTION_COST_PROXIMITY_CHAR;
    145         } else {
    146             cost += ScoringParams::INSERTION_COST;
    147         }
    148         return cost + weightedDistance;
    149     }
    150 
    151     float getNewWordSpatialCost(const DicTraverseSession *const traverseSession,
    152             const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const {
    153         return ScoringParams::COST_NEW_WORD * traverseSession->getMultiWordCostMultiplier();
    154     }
    155 
    156     float getNewWordBigramLanguageCost(const DicTraverseSession *const traverseSession,
    157             const DicNode *const dicNode,
    158             MultiBigramMap *const multiBigramMap) const {
    159         return DicNodeUtils::getBigramNodeImprobability(
    160                 traverseSession->getDictionaryStructurePolicy(),
    161                 dicNode, multiBigramMap) * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
    162     }
    163 
    164     float getCompletionCost(const DicTraverseSession *const traverseSession,
    165             const DicNode *const dicNode) const {
    166         // The auto completion starts when the input index is same as the input size
    167         const bool firstCompletion = dicNode->getInputIndex(0)
    168                 == traverseSession->getInputSize();
    169         // TODO: Change the cost for the first completion for the gesture?
    170         const float cost = firstCompletion ? ScoringParams::COST_FIRST_LOOKAHEAD
    171                 : ScoringParams::COST_LOOKAHEAD;
    172         return cost;
    173     }
    174 
    175     float getTerminalLanguageCost(const DicTraverseSession *const traverseSession,
    176             const DicNode *const dicNode, const float dicNodeLanguageImprobability) const {
    177         return dicNodeLanguageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
    178     }
    179 
    180     float getTerminalInsertionCost(const DicTraverseSession *const traverseSession,
    181             const DicNode *const dicNode) const {
    182         const int inputIndex = dicNode->getInputIndex(0);
    183         const int inputSize = traverseSession->getInputSize();
    184         ASSERT(inputIndex < inputSize);
    185         // TODO: Implement more efficient logic
    186         return  ScoringParams::TERMINAL_INSERTION_COST * (inputSize - inputIndex);
    187     }
    188 
    189     AK_FORCE_INLINE bool needsToNormalizeCompoundDistance() const {
    190         return false;
    191     }
    192 
    193     AK_FORCE_INLINE float getAdditionalProximityCost() const {
    194         return ScoringParams::ADDITIONAL_PROXIMITY_COST;
    195     }
    196 
    197     AK_FORCE_INLINE float getSubstitutionCost() const {
    198         return ScoringParams::SUBSTITUTION_COST;
    199     }
    200 
    201     AK_FORCE_INLINE float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession,
    202             const DicNode *const dicNode) const {
    203         const float cost = ScoringParams::SPACE_SUBSTITUTION_COST + ScoringParams::COST_NEW_WORD;
    204         return cost * traverseSession->getMultiWordCostMultiplier();
    205     }
    206 
    207     ErrorType getErrorType(const CorrectionType correctionType,
    208             const DicTraverseSession *const traverseSession,
    209             const DicNode *const parentDicNode, const DicNode *const dicNode) const;
    210 
    211  private:
    212     DISALLOW_COPY_AND_ASSIGN(TypingWeighting);
    213     static const TypingWeighting sInstance;
    214 
    215     TypingWeighting() {}
    216     ~TypingWeighting() {}
    217 };
    218 } // namespace latinime
    219 #endif // LATINIME_TYPING_WEIGHTING_H
    220