Home | History | Annotate | Download | only in dicnode
      1 /*
      2  * Copyright (C) 2012 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include <cstring>
     18 #include <vector>
     19 
     20 #include "binary_format.h"
     21 #include "dic_node.h"
     22 #include "dic_node_utils.h"
     23 #include "dic_node_vector.h"
     24 #include "multi_bigram_map.h"
     25 #include "proximity_info.h"
     26 #include "proximity_info_state.h"
     27 
     28 namespace latinime {
     29 
     30 ///////////////////////////////
     31 // Node initialization utils //
     32 ///////////////////////////////
     33 
     34 /* static */ void DicNodeUtils::initAsRoot(const int rootPos, const uint8_t *const dicRoot,
     35         const int prevWordNodePos, DicNode *newRootNode) {
     36     int curPos = rootPos;
     37     const int pos = curPos;
     38     const int childrenCount = BinaryFormat::getGroupCountAndForwardPointer(dicRoot, &curPos);
     39     const int childrenPos = curPos;
     40     newRootNode->initAsRoot(pos, childrenPos, childrenCount, prevWordNodePos);
     41 }
     42 
     43 /*static */ void DicNodeUtils::initAsRootWithPreviousWord(const int rootPos,
     44         const uint8_t *const dicRoot, DicNode *prevWordLastNode, DicNode *newRootNode) {
     45     int curPos = rootPos;
     46     const int pos = curPos;
     47     const int childrenCount = BinaryFormat::getGroupCountAndForwardPointer(dicRoot, &curPos);
     48     const int childrenPos = curPos;
     49     newRootNode->initAsRootWithPreviousWord(prevWordLastNode, pos, childrenPos, childrenCount);
     50 }
     51 
     52 /* static */ void DicNodeUtils::initByCopy(DicNode *srcNode, DicNode *destNode) {
     53     destNode->initByCopy(srcNode);
     54 }
     55 
     56 ///////////////////////////////////
     57 // Traverse node expansion utils //
     58 ///////////////////////////////////
     59 
     60 /* static */ void DicNodeUtils::createAndGetPassingChildNode(DicNode *dicNode,
     61         const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly,
     62         DicNodeVector *childDicNodes) {
     63     // Passing multiple chars node. No need to traverse child
     64     const int codePoint = dicNode->getNodeTypedCodePoint();
     65     const int baseLowerCaseCodePoint = toBaseLowerCase(codePoint);
     66     const bool isMatch = isMatchedNodeCodePoint(pInfoState, pointIndex, exactOnly, codePoint);
     67     if (isMatch || isIntentionalOmissionCodePoint(baseLowerCaseCodePoint)) {
     68         childDicNodes->pushPassingChild(dicNode);
     69     }
     70 }
     71 
     72 /* static */ int DicNodeUtils::createAndGetLeavingChildNode(DicNode *dicNode, int pos,
     73         const uint8_t *const dicRoot, const int terminalDepth, const ProximityInfoState *pInfoState,
     74         const int pointIndex, const bool exactOnly, const std::vector<int> *const codePointsFilter,
     75         const ProximityInfo *const pInfo, DicNodeVector *childDicNodes) {
     76     int nextPos = pos;
     77     const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(dicRoot, &pos);
     78     const bool hasMultipleChars = (0 != (BinaryFormat::FLAG_HAS_MULTIPLE_CHARS & flags));
     79     const bool isTerminal = (0 != (BinaryFormat::FLAG_IS_TERMINAL & flags));
     80     const bool hasChildren = BinaryFormat::hasChildrenInFlags(flags);
     81 
     82     int codePoint = BinaryFormat::getCodePointAndForwardPointer(dicRoot, &pos);
     83     ASSERT(NOT_A_CODE_POINT != codePoint);
     84     const int nodeCodePoint = codePoint;
     85     // TODO: optimize this
     86     int additionalWordBuf[MAX_WORD_LENGTH];
     87     uint16_t additionalSubwordLength = 0;
     88     additionalWordBuf[additionalSubwordLength++] = codePoint;
     89 
     90     do {
     91         const int nextCodePoint = hasMultipleChars
     92                 ? BinaryFormat::getCodePointAndForwardPointer(dicRoot, &pos) : NOT_A_CODE_POINT;
     93         const bool isLastChar = (NOT_A_CODE_POINT == nextCodePoint);
     94         if (!isLastChar) {
     95             additionalWordBuf[additionalSubwordLength++] = nextCodePoint;
     96         }
     97         codePoint = nextCodePoint;
     98     } while (NOT_A_CODE_POINT != codePoint);
     99 
    100     const int probability =
    101             isTerminal ? BinaryFormat::readProbabilityWithoutMovingPointer(dicRoot, pos) : -1;
    102     pos = BinaryFormat::skipProbability(flags, pos);
    103     int childrenPos = hasChildren ? BinaryFormat::readChildrenPosition(dicRoot, flags, pos) : 0;
    104     const int attributesPos = BinaryFormat::skipChildrenPosition(flags, pos);
    105     const int siblingPos = BinaryFormat::skipChildrenPosAndAttributes(dicRoot, flags, pos);
    106 
    107     if (isDicNodeFilteredOut(nodeCodePoint, pInfo, codePointsFilter)) {
    108         return siblingPos;
    109     }
    110     if (!isMatchedNodeCodePoint(pInfoState, pointIndex, exactOnly, nodeCodePoint)) {
    111         return siblingPos;
    112     }
    113     const int childrenCount = hasChildren
    114             ? BinaryFormat::getGroupCountAndForwardPointer(dicRoot, &childrenPos) : 0;
    115     childDicNodes->pushLeavingChild(dicNode, nextPos, flags, childrenPos, attributesPos, siblingPos,
    116             nodeCodePoint, childrenCount, probability, -1 /* bigramProbability */, isTerminal,
    117             hasMultipleChars, hasChildren, additionalSubwordLength, additionalWordBuf);
    118     return siblingPos;
    119 }
    120 
    121 /* static */ bool DicNodeUtils::isDicNodeFilteredOut(const int nodeCodePoint,
    122         const ProximityInfo *const pInfo, const std::vector<int> *const codePointsFilter) {
    123     const int filterSize = codePointsFilter ? codePointsFilter->size() : 0;
    124     if (filterSize <= 0) {
    125         return false;
    126     }
    127     if (pInfo && (pInfo->getKeyIndexOf(nodeCodePoint) == NOT_AN_INDEX
    128             || isIntentionalOmissionCodePoint(nodeCodePoint))) {
    129         // If normalized nodeCodePoint is not on the keyboard or skippable, this child is never
    130         // filtered.
    131         return false;
    132     }
    133     const int lowerCodePoint = toLowerCase(nodeCodePoint);
    134     const int baseLowerCodePoint = toBaseCodePoint(lowerCodePoint);
    135     // TODO: Avoid linear search
    136     for (int i = 0; i < filterSize; ++i) {
    137         // Checking if a normalized code point is in filter characters when pInfo is not
    138         // null. When pInfo is null, nodeCodePoint is used to check filtering without
    139         // normalizing.
    140         if ((pInfo && ((*codePointsFilter)[i] == lowerCodePoint
    141                 || (*codePointsFilter)[i] == baseLowerCodePoint))
    142                         || (!pInfo && (*codePointsFilter)[i] == nodeCodePoint)) {
    143             return false;
    144         }
    145     }
    146     return true;
    147 }
    148 
    149 /* static */ void DicNodeUtils::createAndGetAllLeavingChildNodes(DicNode *dicNode,
    150         const uint8_t *const dicRoot, const ProximityInfoState *pInfoState, const int pointIndex,
    151         const bool exactOnly, const std::vector<int> *const codePointsFilter,
    152         const ProximityInfo *const pInfo, DicNodeVector *childDicNodes) {
    153     const int terminalDepth = dicNode->getLeavingDepth();
    154     const int childCount = dicNode->getChildrenCount();
    155     int nextPos = dicNode->getChildrenPos();
    156     for (int i = 0; i < childCount; i++) {
    157         const int filterSize = codePointsFilter ? codePointsFilter->size() : 0;
    158         nextPos = createAndGetLeavingChildNode(dicNode, nextPos, dicRoot, terminalDepth, pInfoState,
    159                 pointIndex, exactOnly, codePointsFilter, pInfo, childDicNodes);
    160         if (!pInfo && filterSize > 0 && childDicNodes->exceeds(filterSize)) {
    161             // All code points have been found.
    162             break;
    163         }
    164     }
    165 }
    166 
    167 /* static */ void DicNodeUtils::getAllChildDicNodes(DicNode *dicNode, const uint8_t *const dicRoot,
    168         DicNodeVector *childDicNodes) {
    169     getProximityChildDicNodes(dicNode, dicRoot, 0, 0, false, childDicNodes);
    170 }
    171 
    172 /* static */ void DicNodeUtils::getProximityChildDicNodes(DicNode *dicNode,
    173         const uint8_t *const dicRoot, const ProximityInfoState *pInfoState, const int pointIndex,
    174         bool exactOnly, DicNodeVector *childDicNodes) {
    175     if (dicNode->isTotalInputSizeExceedingLimit()) {
    176         return;
    177     }
    178     if (!dicNode->isLeavingNode()) {
    179         DicNodeUtils::createAndGetPassingChildNode(dicNode, pInfoState, pointIndex, exactOnly,
    180                 childDicNodes);
    181     } else {
    182         DicNodeUtils::createAndGetAllLeavingChildNodes(dicNode, dicRoot, pInfoState, pointIndex,
    183                 exactOnly, 0 /* codePointsFilter */, 0 /* pInfo */,
    184                 childDicNodes);
    185     }
    186 }
    187 
    188 ///////////////////
    189 // Scoring utils //
    190 ///////////////////
    191 /**
    192  * Computes the combined bigram / unigram cost for the given dicNode.
    193  */
    194 /* static */ float DicNodeUtils::getBigramNodeImprobability(const uint8_t *const dicRoot,
    195         const DicNode *const node, MultiBigramMap *multiBigramMap) {
    196     if (node->isImpossibleBigramWord()) {
    197         return static_cast<float>(MAX_VALUE_FOR_WEIGHTING);
    198     }
    199     const int probability = getBigramNodeProbability(dicRoot, node, multiBigramMap);
    200     // TODO: This equation to calculate the improbability looks unreasonable.  Investigate this.
    201     const float cost = static_cast<float>(MAX_PROBABILITY - probability)
    202             / static_cast<float>(MAX_PROBABILITY);
    203     return cost;
    204 }
    205 
    206 /* static */ int DicNodeUtils::getBigramNodeProbability(const uint8_t *const dicRoot,
    207         const DicNode *const node, MultiBigramMap *multiBigramMap) {
    208     const int unigramProbability = node->getProbability();
    209     const int wordPos = node->getPos();
    210     const int prevWordPos = node->getPrevWordPos();
    211     if (NOT_VALID_WORD == wordPos || NOT_VALID_WORD == prevWordPos) {
    212         // Note: Normally wordPos comes from the dictionary and should never equal NOT_VALID_WORD.
    213         return backoff(unigramProbability);
    214     }
    215     if (multiBigramMap) {
    216         return multiBigramMap->getBigramProbability(
    217                 dicRoot, prevWordPos, wordPos, unigramProbability);
    218     }
    219     return BinaryFormat::getBigramProbability(dicRoot, prevWordPos, wordPos, unigramProbability);
    220 }
    221 
    222 ///////////////////////////////////////
    223 // Bigram / Unigram dictionary utils //
    224 ///////////////////////////////////////
    225 
    226 /* static */ bool DicNodeUtils::isMatchedNodeCodePoint(const ProximityInfoState *pInfoState,
    227         const int pointIndex, const bool exactOnly, const int nodeCodePoint) {
    228     if (!pInfoState) {
    229         return true;
    230     }
    231     if (exactOnly) {
    232         return pInfoState->getPrimaryCodePointAt(pointIndex) == nodeCodePoint;
    233     }
    234     const ProximityType matchedId = pInfoState->getProximityType(pointIndex, nodeCodePoint,
    235             true /* checkProximityChars */);
    236     return isProximityChar(matchedId);
    237 }
    238 
    239 ////////////////
    240 // Char utils //
    241 ////////////////
    242 
    243 // TODO: Move to char_utils?
    244 /* static */ int DicNodeUtils::appendTwoWords(const int *const src0, const int16_t length0,
    245         const int *const src1, const int16_t length1, int *dest) {
    246     int actualLength0 = 0;
    247     for (int i = 0; i < length0; ++i) {
    248         if (src0[i] == 0) {
    249             break;
    250         }
    251         actualLength0 = i + 1;
    252     }
    253     actualLength0 = min(actualLength0, MAX_WORD_LENGTH);
    254     memcpy(dest, src0, actualLength0 * sizeof(dest[0]));
    255     if (!src1 || length1 == 0) {
    256         return actualLength0;
    257     }
    258     int actualLength1 = 0;
    259     for (int i = 0; i < length1; ++i) {
    260         if (src1[i] == 0) {
    261             break;
    262         }
    263         actualLength1 = i + 1;
    264     }
    265     actualLength1 = min(actualLength1, MAX_WORD_LENGTH - actualLength0 - 1);
    266     memcpy(&dest[actualLength0], src1, actualLength1 * sizeof(dest[0]));
    267     return actualLength0 + actualLength1;
    268 }
    269 } // namespace latinime
    270