Home | History | Annotate | Download | only in policy
      1 /*
      2  * Copyright (C) 2013 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "suggest/core/policy/weighting.h"
     18 
     19 #include "defines.h"
     20 #include "suggest/core/dicnode/dic_node.h"
     21 #include "suggest/core/dicnode/dic_node_profiler.h"
     22 #include "suggest/core/dicnode/dic_node_utils.h"
     23 #include "suggest/core/session/dic_traverse_session.h"
     24 
     25 namespace latinime {
     26 
     27 class MultiBigramMap;
     28 
     29 static inline void profile(const CorrectionType correctionType, DicNode *const node) {
     30 #if DEBUG_DICT
     31     switch (correctionType) {
     32     case CT_OMISSION:
     33         PROF_OMISSION(node->mProfiler);
     34         return;
     35     case CT_ADDITIONAL_PROXIMITY:
     36         PROF_ADDITIONAL_PROXIMITY(node->mProfiler);
     37         return;
     38     case CT_SUBSTITUTION:
     39         PROF_SUBSTITUTION(node->mProfiler);
     40         return;
     41     case CT_NEW_WORD_SPACE_OMISSION:
     42         PROF_NEW_WORD(node->mProfiler);
     43         return;
     44     case CT_MATCH:
     45         PROF_MATCH(node->mProfiler);
     46         return;
     47     case CT_COMPLETION:
     48         PROF_COMPLETION(node->mProfiler);
     49         return;
     50     case CT_TERMINAL:
     51         PROF_TERMINAL(node->mProfiler);
     52         return;
     53     case CT_TERMINAL_INSERTION:
     54         PROF_TERMINAL_INSERTION(node->mProfiler);
     55         return;
     56     case CT_NEW_WORD_SPACE_SUBSTITUTION:
     57         PROF_SPACE_SUBSTITUTION(node->mProfiler);
     58         return;
     59     case CT_INSERTION:
     60         PROF_INSERTION(node->mProfiler);
     61         return;
     62     case CT_TRANSPOSITION:
     63         PROF_TRANSPOSITION(node->mProfiler);
     64         return;
     65     default:
     66         // do nothing
     67         return;
     68     }
     69 #else
     70     // do nothing
     71 #endif
     72 }
     73 
     74 /* static */ void Weighting::addCostAndForwardInputIndex(const Weighting *const weighting,
     75         const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
     76         const DicNode *const parentDicNode, DicNode *const dicNode,
     77         MultiBigramMap *const multiBigramMap) {
     78     const int inputSize = traverseSession->getInputSize();
     79     DicNode_InputStateG inputStateG;
     80     inputStateG.mNeedsToUpdateInputStateG = false; // Don't use input info by default
     81     const float spatialCost = Weighting::getSpatialCost(weighting, correctionType,
     82             traverseSession, parentDicNode, dicNode, &inputStateG);
     83     const float languageCost = Weighting::getLanguageCost(weighting, correctionType,
     84             traverseSession, parentDicNode, dicNode, multiBigramMap);
     85     const ErrorType errorType = weighting->getErrorType(correctionType, traverseSession,
     86             parentDicNode, dicNode);
     87     profile(correctionType, dicNode);
     88     if (inputStateG.mNeedsToUpdateInputStateG) {
     89         dicNode->updateInputIndexG(&inputStateG);
     90     } else {
     91         dicNode->forwardInputIndex(0, getForwardInputCount(correctionType),
     92                 (correctionType == CT_TRANSPOSITION));
     93     }
     94     dicNode->addCost(spatialCost, languageCost, weighting->needsToNormalizeCompoundDistance(),
     95             inputSize, errorType);
     96     if (CT_NEW_WORD_SPACE_OMISSION == correctionType) {
     97         // When we are on a terminal, we save the current distance for evaluating
     98         // when to auto-commit partial suggestions.
     99         dicNode->saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet();
    100     }
    101 }
    102 
    103 /* static */ float Weighting::getSpatialCost(const Weighting *const weighting,
    104         const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
    105         const DicNode *const parentDicNode, const DicNode *const dicNode,
    106         DicNode_InputStateG *const inputStateG) {
    107     switch(correctionType) {
    108     case CT_OMISSION:
    109         return weighting->getOmissionCost(parentDicNode, dicNode);
    110     case CT_ADDITIONAL_PROXIMITY:
    111         // only used for typing
    112         return weighting->getAdditionalProximityCost();
    113     case CT_SUBSTITUTION:
    114         // only used for typing
    115         return weighting->getSubstitutionCost();
    116     case CT_NEW_WORD_SPACE_OMISSION:
    117         return weighting->getNewWordSpatialCost(traverseSession, dicNode, inputStateG);
    118     case CT_MATCH:
    119         return weighting->getMatchedCost(traverseSession, dicNode, inputStateG);
    120     case CT_COMPLETION:
    121         return weighting->getCompletionCost(traverseSession, dicNode);
    122     case CT_TERMINAL:
    123         return weighting->getTerminalSpatialCost(traverseSession, dicNode);
    124     case CT_TERMINAL_INSERTION:
    125         return weighting->getTerminalInsertionCost(traverseSession, dicNode);
    126     case CT_NEW_WORD_SPACE_SUBSTITUTION:
    127         return weighting->getSpaceSubstitutionCost(traverseSession, dicNode);
    128     case CT_INSERTION:
    129         return weighting->getInsertionCost(traverseSession, parentDicNode, dicNode);
    130     case CT_TRANSPOSITION:
    131         return weighting->getTranspositionCost(traverseSession, parentDicNode, dicNode);
    132     default:
    133         return 0.0f;
    134     }
    135 }
    136 
    137 /* static */ float Weighting::getLanguageCost(const Weighting *const weighting,
    138         const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
    139         const DicNode *const parentDicNode, const DicNode *const dicNode,
    140         MultiBigramMap *const multiBigramMap) {
    141     switch(correctionType) {
    142     case CT_OMISSION:
    143         return 0.0f;
    144     case CT_SUBSTITUTION:
    145         return 0.0f;
    146     case CT_NEW_WORD_SPACE_OMISSION:
    147         return weighting->getNewWordBigramLanguageCost(
    148                 traverseSession, parentDicNode, multiBigramMap);
    149     case CT_MATCH:
    150         return 0.0f;
    151     case CT_COMPLETION:
    152         return 0.0f;
    153     case CT_TERMINAL: {
    154         const float languageImprobability =
    155                 DicNodeUtils::getBigramNodeImprobability(
    156                         traverseSession->getDictionaryStructurePolicy(), dicNode, multiBigramMap);
    157         return weighting->getTerminalLanguageCost(traverseSession, dicNode, languageImprobability);
    158     }
    159     case CT_TERMINAL_INSERTION:
    160         return 0.0f;
    161     case CT_NEW_WORD_SPACE_SUBSTITUTION:
    162         return weighting->getNewWordBigramLanguageCost(
    163                 traverseSession, parentDicNode, multiBigramMap);
    164     case CT_INSERTION:
    165         return 0.0f;
    166     case CT_TRANSPOSITION:
    167         return 0.0f;
    168     default:
    169         return 0.0f;
    170     }
    171 }
    172 
    173 /* static */ int Weighting::getForwardInputCount(const CorrectionType correctionType) {
    174     switch(correctionType) {
    175         case CT_OMISSION:
    176             return 0;
    177         case CT_ADDITIONAL_PROXIMITY:
    178             return 0; /* 0 because CT_MATCH will be called */
    179         case CT_SUBSTITUTION:
    180             return 0; /* 0 because CT_MATCH will be called */
    181         case CT_NEW_WORD_SPACE_OMISSION:
    182             return 0;
    183         case CT_MATCH:
    184             return 1;
    185         case CT_COMPLETION:
    186             return 1;
    187         case CT_TERMINAL:
    188             return 0;
    189         case CT_TERMINAL_INSERTION:
    190             return 1;
    191         case CT_NEW_WORD_SPACE_SUBSTITUTION:
    192             return 1;
    193         case CT_INSERTION:
    194             return 2; /* look ahead + skip the current char */
    195         case CT_TRANSPOSITION:
    196             return 2; /* look ahead + skip the current char */
    197         default:
    198             return 0;
    199     }
    200 }
    201 }  // namespace latinime
    202