Home | History | Annotate | Download | only in policy
      1 /*
      2  * Copyright (C) 2013 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "suggest/core/policy/weighting.h"
     18 
     19 #include "defines.h"
     20 #include "suggest/core/dicnode/dic_node.h"
     21 #include "suggest/core/dicnode/dic_node_profiler.h"
     22 #include "suggest/core/dicnode/dic_node_utils.h"
     23 #include "suggest/core/dictionary/error_type_utils.h"
     24 #include "suggest/core/session/dic_traverse_session.h"
     25 
     26 namespace latinime {
     27 
     28 class MultiBigramMap;
     29 
     30 static inline void profile(const CorrectionType correctionType, DicNode *const node) {
     31 #if DEBUG_DICT
     32     switch (correctionType) {
     33     case CT_OMISSION:
     34         PROF_OMISSION(node->mProfiler);
     35         return;
     36     case CT_ADDITIONAL_PROXIMITY:
     37         PROF_ADDITIONAL_PROXIMITY(node->mProfiler);
     38         return;
     39     case CT_SUBSTITUTION:
     40         PROF_SUBSTITUTION(node->mProfiler);
     41         return;
     42     case CT_NEW_WORD_SPACE_OMISSION:
     43         PROF_NEW_WORD(node->mProfiler);
     44         return;
     45     case CT_MATCH:
     46         PROF_MATCH(node->mProfiler);
     47         return;
     48     case CT_COMPLETION:
     49         PROF_COMPLETION(node->mProfiler);
     50         return;
     51     case CT_TERMINAL:
     52         PROF_TERMINAL(node->mProfiler);
     53         return;
     54     case CT_TERMINAL_INSERTION:
     55         PROF_TERMINAL_INSERTION(node->mProfiler);
     56         return;
     57     case CT_NEW_WORD_SPACE_SUBSTITUTION:
     58         PROF_SPACE_SUBSTITUTION(node->mProfiler);
     59         return;
     60     case CT_INSERTION:
     61         PROF_INSERTION(node->mProfiler);
     62         return;
     63     case CT_TRANSPOSITION:
     64         PROF_TRANSPOSITION(node->mProfiler);
     65         return;
     66     default:
     67         // do nothing
     68         return;
     69     }
     70 #else
     71     // do nothing
     72 #endif
     73 }
     74 
     75 /* static */ void Weighting::addCostAndForwardInputIndex(const Weighting *const weighting,
     76         const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
     77         const DicNode *const parentDicNode, DicNode *const dicNode,
     78         MultiBigramMap *const multiBigramMap) {
     79     const int inputSize = traverseSession->getInputSize();
     80     DicNode_InputStateG inputStateG;
     81     inputStateG.mNeedsToUpdateInputStateG = false; // Don't use input info by default
     82     const float spatialCost = Weighting::getSpatialCost(weighting, correctionType,
     83             traverseSession, parentDicNode, dicNode, &inputStateG);
     84     const float languageCost = Weighting::getLanguageCost(weighting, correctionType,
     85             traverseSession, parentDicNode, dicNode, multiBigramMap);
     86     const ErrorTypeUtils::ErrorType errorType = weighting->getErrorType(correctionType,
     87             traverseSession, parentDicNode, dicNode);
     88     profile(correctionType, dicNode);
     89     if (inputStateG.mNeedsToUpdateInputStateG) {
     90         dicNode->updateInputIndexG(&inputStateG);
     91     } else {
     92         dicNode->forwardInputIndex(0, getForwardInputCount(correctionType),
     93                 (correctionType == CT_TRANSPOSITION));
     94     }
     95     dicNode->addCost(spatialCost, languageCost, weighting->needsToNormalizeCompoundDistance(),
     96             inputSize, errorType);
     97     if (CT_NEW_WORD_SPACE_OMISSION == correctionType) {
     98         // When we are on a terminal, we save the current distance for evaluating
     99         // when to auto-commit partial suggestions.
    100         dicNode->saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet();
    101     }
    102 }
    103 
    104 /* static */ float Weighting::getSpatialCost(const Weighting *const weighting,
    105         const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
    106         const DicNode *const parentDicNode, const DicNode *const dicNode,
    107         DicNode_InputStateG *const inputStateG) {
    108     switch(correctionType) {
    109     case CT_OMISSION:
    110         return weighting->getOmissionCost(parentDicNode, dicNode);
    111     case CT_ADDITIONAL_PROXIMITY:
    112         // only used for typing
    113         // TODO: Quit calling getMatchedCost().
    114         return weighting->getAdditionalProximityCost()
    115                 + weighting->getMatchedCost(traverseSession, dicNode, inputStateG);
    116     case CT_SUBSTITUTION:
    117         // only used for typing
    118         // TODO: Quit calling getMatchedCost().
    119         return weighting->getSubstitutionCost()
    120                 + weighting->getMatchedCost(traverseSession, dicNode, inputStateG);
    121     case CT_NEW_WORD_SPACE_OMISSION:
    122         return weighting->getSpaceOmissionCost(traverseSession, dicNode, inputStateG);
    123     case CT_MATCH:
    124         return weighting->getMatchedCost(traverseSession, dicNode, inputStateG);
    125     case CT_COMPLETION:
    126         return weighting->getCompletionCost(traverseSession, dicNode);
    127     case CT_TERMINAL:
    128         return weighting->getTerminalSpatialCost(traverseSession, dicNode);
    129     case CT_TERMINAL_INSERTION:
    130         return weighting->getTerminalInsertionCost(traverseSession, dicNode);
    131     case CT_NEW_WORD_SPACE_SUBSTITUTION:
    132         return weighting->getSpaceSubstitutionCost(traverseSession, dicNode);
    133     case CT_INSERTION:
    134         return weighting->getInsertionCost(traverseSession, parentDicNode, dicNode);
    135     case CT_TRANSPOSITION:
    136         return weighting->getTranspositionCost(traverseSession, parentDicNode, dicNode);
    137     default:
    138         return 0.0f;
    139     }
    140 }
    141 
    142 /* static */ float Weighting::getLanguageCost(const Weighting *const weighting,
    143         const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
    144         const DicNode *const parentDicNode, const DicNode *const dicNode,
    145         MultiBigramMap *const multiBigramMap) {
    146     switch(correctionType) {
    147     case CT_OMISSION:
    148         return 0.0f;
    149     case CT_SUBSTITUTION:
    150         return 0.0f;
    151     case CT_NEW_WORD_SPACE_OMISSION:
    152         return weighting->getNewWordBigramLanguageCost(
    153                 traverseSession, parentDicNode, multiBigramMap);
    154     case CT_MATCH:
    155         return 0.0f;
    156     case CT_COMPLETION:
    157         return 0.0f;
    158     case CT_TERMINAL: {
    159         const float languageImprobability =
    160                 DicNodeUtils::getBigramNodeImprobability(
    161                         traverseSession->getDictionaryStructurePolicy(), dicNode, multiBigramMap);
    162         return weighting->getTerminalLanguageCost(traverseSession, dicNode, languageImprobability);
    163     }
    164     case CT_TERMINAL_INSERTION:
    165         return 0.0f;
    166     case CT_NEW_WORD_SPACE_SUBSTITUTION:
    167         return weighting->getNewWordBigramLanguageCost(
    168                 traverseSession, parentDicNode, multiBigramMap);
    169     case CT_INSERTION:
    170         return 0.0f;
    171     case CT_TRANSPOSITION:
    172         return 0.0f;
    173     default:
    174         return 0.0f;
    175     }
    176 }
    177 
    178 /* static */ int Weighting::getForwardInputCount(const CorrectionType correctionType) {
    179     switch(correctionType) {
    180         case CT_OMISSION:
    181             return 0;
    182         case CT_ADDITIONAL_PROXIMITY:
    183             return 1;
    184         case CT_SUBSTITUTION:
    185             return 1;
    186         case CT_NEW_WORD_SPACE_OMISSION:
    187             return 0;
    188         case CT_MATCH:
    189             return 1;
    190         case CT_COMPLETION:
    191             return 1;
    192         case CT_TERMINAL:
    193             return 0;
    194         case CT_TERMINAL_INSERTION:
    195             return 1;
    196         case CT_NEW_WORD_SPACE_SUBSTITUTION:
    197             return 1;
    198         case CT_INSERTION:
    199             return 2; /* look ahead + skip the current char */
    200         case CT_TRANSPOSITION:
    201             return 2; /* look ahead + skip the current char */
    202         default:
    203             return 0;
    204     }
    205 }
    206 }  // namespace latinime
    207