Home | History | Annotate | Download | only in policy
      1 /*
      2  * Copyright (C) 2013 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "suggest/core/policy/weighting.h"
     18 
     19 #include "defines.h"
     20 #include "suggest/core/dicnode/dic_node.h"
     21 #include "suggest/core/dicnode/dic_node_profiler.h"
     22 #include "suggest/core/dicnode/dic_node_utils.h"
     23 #include "suggest/core/dictionary/error_type_utils.h"
     24 #include "suggest/core/session/dic_traverse_session.h"
     25 
     26 namespace latinime {
     27 
     28 class MultiBigramMap;
     29 
     30 static inline void profile(const CorrectionType correctionType, DicNode *const node) {
     31 #if DEBUG_DICT
     32     switch (correctionType) {
     33     case CT_OMISSION:
     34         PROF_OMISSION(node->mProfiler);
     35         return;
     36     case CT_ADDITIONAL_PROXIMITY:
     37         PROF_ADDITIONAL_PROXIMITY(node->mProfiler);
     38         return;
     39     case CT_SUBSTITUTION:
     40         PROF_SUBSTITUTION(node->mProfiler);
     41         return;
     42     case CT_NEW_WORD_SPACE_OMISSION:
     43         PROF_NEW_WORD(node->mProfiler);
     44         return;
     45     case CT_MATCH:
     46         PROF_MATCH(node->mProfiler);
     47         return;
     48     case CT_COMPLETION:
     49         PROF_COMPLETION(node->mProfiler);
     50         return;
     51     case CT_TERMINAL:
     52         PROF_TERMINAL(node->mProfiler);
     53         return;
     54     case CT_TERMINAL_INSERTION:
     55         PROF_TERMINAL_INSERTION(node->mProfiler);
     56         return;
     57     case CT_NEW_WORD_SPACE_SUBSTITUTION:
     58         PROF_SPACE_SUBSTITUTION(node->mProfiler);
     59         return;
     60     case CT_INSERTION:
     61         PROF_INSERTION(node->mProfiler);
     62         return;
     63     case CT_TRANSPOSITION:
     64         PROF_TRANSPOSITION(node->mProfiler);
     65         return;
     66     default:
     67         // do nothing
     68         return;
     69     }
     70 #else
     71     // do nothing
     72 #endif
     73 }
     74 
     75 /* static */ void Weighting::addCostAndForwardInputIndex(const Weighting *const weighting,
     76         const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
     77         const DicNode *const parentDicNode, DicNode *const dicNode,
     78         MultiBigramMap *const multiBigramMap) {
     79     const int inputSize = traverseSession->getInputSize();
     80     DicNode_InputStateG inputStateG;
     81     inputStateG.mNeedsToUpdateInputStateG = false; // Don't use input info by default
     82     const float spatialCost = Weighting::getSpatialCost(weighting, correctionType,
     83             traverseSession, parentDicNode, dicNode, &inputStateG);
     84     const float languageCost = Weighting::getLanguageCost(weighting, correctionType,
     85             traverseSession, parentDicNode, dicNode, multiBigramMap);
     86     const ErrorTypeUtils::ErrorType errorType = weighting->getErrorType(correctionType,
     87             traverseSession, parentDicNode, dicNode);
     88     profile(correctionType, dicNode);
     89     if (inputStateG.mNeedsToUpdateInputStateG) {
     90         dicNode->updateInputIndexG(&inputStateG);
     91     } else {
     92         dicNode->forwardInputIndex(0, getForwardInputCount(correctionType),
     93                 (correctionType == CT_TRANSPOSITION));
     94     }
     95     dicNode->addCost(spatialCost, languageCost, weighting->needsToNormalizeCompoundDistance(),
     96             inputSize, errorType);
     97     if (CT_NEW_WORD_SPACE_OMISSION == correctionType) {
     98         // When we are on a terminal, we save the current distance for evaluating
     99         // when to auto-commit partial suggestions.
    100         dicNode->saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet();
    101     }
    102 }
    103 
    104 /* static */ float Weighting::getSpatialCost(const Weighting *const weighting,
    105         const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
    106         const DicNode *const parentDicNode, const DicNode *const dicNode,
    107         DicNode_InputStateG *const inputStateG) {
    108     switch(correctionType) {
    109     case CT_OMISSION:
    110         return weighting->getOmissionCost(parentDicNode, dicNode);
    111     case CT_ADDITIONAL_PROXIMITY:
    112         // only used for typing
    113         return weighting->getAdditionalProximityCost();
    114     case CT_SUBSTITUTION:
    115         // only used for typing
    116         return weighting->getSubstitutionCost();
    117     case CT_NEW_WORD_SPACE_OMISSION:
    118         return weighting->getNewWordSpatialCost(traverseSession, dicNode, inputStateG);
    119     case CT_MATCH:
    120         return weighting->getMatchedCost(traverseSession, dicNode, inputStateG);
    121     case CT_COMPLETION:
    122         return weighting->getCompletionCost(traverseSession, dicNode);
    123     case CT_TERMINAL:
    124         return weighting->getTerminalSpatialCost(traverseSession, dicNode);
    125     case CT_TERMINAL_INSERTION:
    126         return weighting->getTerminalInsertionCost(traverseSession, dicNode);
    127     case CT_NEW_WORD_SPACE_SUBSTITUTION:
    128         return weighting->getSpaceSubstitutionCost(traverseSession, dicNode);
    129     case CT_INSERTION:
    130         return weighting->getInsertionCost(traverseSession, parentDicNode, dicNode);
    131     case CT_TRANSPOSITION:
    132         return weighting->getTranspositionCost(traverseSession, parentDicNode, dicNode);
    133     default:
    134         return 0.0f;
    135     }
    136 }
    137 
    138 /* static */ float Weighting::getLanguageCost(const Weighting *const weighting,
    139         const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
    140         const DicNode *const parentDicNode, const DicNode *const dicNode,
    141         MultiBigramMap *const multiBigramMap) {
    142     switch(correctionType) {
    143     case CT_OMISSION:
    144         return 0.0f;
    145     case CT_SUBSTITUTION:
    146         return 0.0f;
    147     case CT_NEW_WORD_SPACE_OMISSION:
    148         return weighting->getNewWordBigramLanguageCost(
    149                 traverseSession, parentDicNode, multiBigramMap);
    150     case CT_MATCH:
    151         return 0.0f;
    152     case CT_COMPLETION:
    153         return 0.0f;
    154     case CT_TERMINAL: {
    155         const float languageImprobability =
    156                 DicNodeUtils::getBigramNodeImprobability(
    157                         traverseSession->getDictionaryStructurePolicy(), dicNode, multiBigramMap);
    158         return weighting->getTerminalLanguageCost(traverseSession, dicNode, languageImprobability);
    159     }
    160     case CT_TERMINAL_INSERTION:
    161         return 0.0f;
    162     case CT_NEW_WORD_SPACE_SUBSTITUTION:
    163         return weighting->getNewWordBigramLanguageCost(
    164                 traverseSession, parentDicNode, multiBigramMap);
    165     case CT_INSERTION:
    166         return 0.0f;
    167     case CT_TRANSPOSITION:
    168         return 0.0f;
    169     default:
    170         return 0.0f;
    171     }
    172 }
    173 
    174 /* static */ int Weighting::getForwardInputCount(const CorrectionType correctionType) {
    175     switch(correctionType) {
    176         case CT_OMISSION:
    177             return 0;
    178         case CT_ADDITIONAL_PROXIMITY:
    179             return 0; /* 0 because CT_MATCH will be called */
    180         case CT_SUBSTITUTION:
    181             return 0; /* 0 because CT_MATCH will be called */
    182         case CT_NEW_WORD_SPACE_OMISSION:
    183             return 0;
    184         case CT_MATCH:
    185             return 1;
    186         case CT_COMPLETION:
    187             return 1;
    188         case CT_TERMINAL:
    189             return 0;
    190         case CT_TERMINAL_INSERTION:
    191             return 1;
    192         case CT_NEW_WORD_SPACE_SUBSTITUTION:
    193             return 1;
    194         case CT_INSERTION:
    195             return 2; /* look ahead + skip the current char */
    196         case CT_TRANSPOSITION:
    197             return 2; /* look ahead + skip the current char */
    198         default:
    199             return 0;
    200     }
    201 }
    202 }  // namespace latinime
    203