Home | History | Annotate | Download | only in policy
      1 /*
      2  * Copyright (C) 2013 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "suggest/core/policy/weighting.h"
     18 
     19 #include "char_utils.h"
     20 #include "defines.h"
     21 #include "suggest/core/dicnode/dic_node.h"
     22 #include "suggest/core/dicnode/dic_node_profiler.h"
     23 #include "suggest/core/dicnode/dic_node_utils.h"
     24 #include "suggest/core/session/dic_traverse_session.h"
     25 
     26 namespace latinime {
     27 
     28 class MultiBigramMap;
     29 
     30 static inline void profile(const CorrectionType correctionType, DicNode *const node) {
     31 #if DEBUG_DICT
     32     switch (correctionType) {
     33     case CT_OMISSION:
     34         PROF_OMISSION(node->mProfiler);
     35         return;
     36     case CT_ADDITIONAL_PROXIMITY:
     37         PROF_ADDITIONAL_PROXIMITY(node->mProfiler);
     38         return;
     39     case CT_SUBSTITUTION:
     40         PROF_SUBSTITUTION(node->mProfiler);
     41         return;
     42     case CT_NEW_WORD_SPACE_OMITTION:
     43         PROF_NEW_WORD(node->mProfiler);
     44         return;
     45     case CT_MATCH:
     46         PROF_MATCH(node->mProfiler);
     47         return;
     48     case CT_COMPLETION:
     49         PROF_COMPLETION(node->mProfiler);
     50         return;
     51     case CT_TERMINAL:
     52         PROF_TERMINAL(node->mProfiler);
     53         return;
     54     case CT_NEW_WORD_SPACE_SUBSTITUTION:
     55         PROF_SPACE_SUBSTITUTION(node->mProfiler);
     56         return;
     57     case CT_INSERTION:
     58         PROF_INSERTION(node->mProfiler);
     59         return;
     60     case CT_TRANSPOSITION:
     61         PROF_TRANSPOSITION(node->mProfiler);
     62         return;
     63     default:
     64         // do nothing
     65         return;
     66     }
     67 #else
     68     // do nothing
     69 #endif
     70 }
     71 
     72 /* static */ void Weighting::addCostAndForwardInputIndex(const Weighting *const weighting,
     73         const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
     74         const DicNode *const parentDicNode, DicNode *const dicNode,
     75         MultiBigramMap *const multiBigramMap) {
     76     const int inputSize = traverseSession->getInputSize();
     77     DicNode_InputStateG inputStateG;
     78     inputStateG.mNeedsToUpdateInputStateG = false; // Don't use input info by default
     79     const float spatialCost = Weighting::getSpatialCost(weighting, correctionType,
     80             traverseSession, parentDicNode, dicNode, &inputStateG);
     81     const float languageCost = Weighting::getLanguageCost(weighting, correctionType,
     82             traverseSession, parentDicNode, dicNode, multiBigramMap);
     83     const ErrorType errorType = weighting->getErrorType(correctionType, traverseSession,
     84             parentDicNode, dicNode);
     85     profile(correctionType, dicNode);
     86     if (inputStateG.mNeedsToUpdateInputStateG) {
     87         dicNode->updateInputIndexG(&inputStateG);
     88     } else {
     89         dicNode->forwardInputIndex(0, getForwardInputCount(correctionType),
     90                 (correctionType == CT_TRANSPOSITION));
     91     }
     92     dicNode->addCost(spatialCost, languageCost, weighting->needsToNormalizeCompoundDistance(),
     93             inputSize, errorType);
     94 }
     95 
     96 /* static */ float Weighting::getSpatialCost(const Weighting *const weighting,
     97         const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
     98         const DicNode *const parentDicNode, const DicNode *const dicNode,
     99         DicNode_InputStateG *const inputStateG) {
    100     switch(correctionType) {
    101     case CT_OMISSION:
    102         return weighting->getOmissionCost(parentDicNode, dicNode);
    103     case CT_ADDITIONAL_PROXIMITY:
    104         // only used for typing
    105         return weighting->getAdditionalProximityCost();
    106     case CT_SUBSTITUTION:
    107         // only used for typing
    108         return weighting->getSubstitutionCost();
    109     case CT_NEW_WORD_SPACE_OMITTION:
    110         return weighting->getNewWordCost(traverseSession, dicNode);
    111     case CT_MATCH:
    112         return weighting->getMatchedCost(traverseSession, dicNode, inputStateG);
    113     case CT_COMPLETION:
    114         return weighting->getCompletionCost(traverseSession, dicNode);
    115     case CT_TERMINAL:
    116         return weighting->getTerminalSpatialCost(traverseSession, dicNode);
    117     case CT_NEW_WORD_SPACE_SUBSTITUTION:
    118         return weighting->getSpaceSubstitutionCost(traverseSession, dicNode);
    119     case CT_INSERTION:
    120         return weighting->getInsertionCost(traverseSession, parentDicNode, dicNode);
    121     case CT_TRANSPOSITION:
    122         return weighting->getTranspositionCost(traverseSession, parentDicNode, dicNode);
    123     default:
    124         return 0.0f;
    125     }
    126 }
    127 
    128 /* static */ float Weighting::getLanguageCost(const Weighting *const weighting,
    129         const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
    130         const DicNode *const parentDicNode, const DicNode *const dicNode,
    131         MultiBigramMap *const multiBigramMap) {
    132     switch(correctionType) {
    133     case CT_OMISSION:
    134         return 0.0f;
    135     case CT_SUBSTITUTION:
    136         return 0.0f;
    137     case CT_NEW_WORD_SPACE_OMITTION:
    138         return weighting->getNewWordBigramCost(traverseSession, parentDicNode, multiBigramMap);
    139     case CT_MATCH:
    140         return 0.0f;
    141     case CT_COMPLETION:
    142         return 0.0f;
    143     case CT_TERMINAL: {
    144         const float languageImprobability =
    145                 DicNodeUtils::getBigramNodeImprobability(
    146                         traverseSession->getOffsetDict(), dicNode, multiBigramMap);
    147         return weighting->getTerminalLanguageCost(traverseSession, dicNode, languageImprobability);
    148     }
    149     case CT_NEW_WORD_SPACE_SUBSTITUTION:
    150         return weighting->getNewWordBigramCost(traverseSession, parentDicNode, multiBigramMap);
    151     case CT_INSERTION:
    152         return 0.0f;
    153     case CT_TRANSPOSITION:
    154         return 0.0f;
    155     default:
    156         return 0.0f;
    157     }
    158 }
    159 
    160 /* static */ int Weighting::getForwardInputCount(const CorrectionType correctionType) {
    161     switch(correctionType) {
    162         case CT_OMISSION:
    163             return 0;
    164         case CT_ADDITIONAL_PROXIMITY:
    165             return 0;
    166         case CT_SUBSTITUTION:
    167             return 0;
    168         case CT_NEW_WORD_SPACE_OMITTION:
    169             return 0;
    170         case CT_MATCH:
    171             return 1;
    172         case CT_COMPLETION:
    173             return 1;
    174         case CT_TERMINAL:
    175             return 0;
    176         case CT_NEW_WORD_SPACE_SUBSTITUTION:
    177             return 1;
    178         case CT_INSERTION:
    179             return 2;
    180         case CT_TRANSPOSITION:
    181             return 2;
    182         default:
    183             return 0;
    184     }
    185 }
    186 }  // namespace latinime
    187