1 /* 2 * Copyright (C) 2012 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include <cstring> 18 #include <vector> 19 20 #include "binary_format.h" 21 #include "dic_node.h" 22 #include "dic_node_utils.h" 23 #include "dic_node_vector.h" 24 #include "multi_bigram_map.h" 25 #include "proximity_info.h" 26 #include "proximity_info_state.h" 27 28 namespace latinime { 29 30 /////////////////////////////// 31 // Node initialization utils // 32 /////////////////////////////// 33 34 /* static */ void DicNodeUtils::initAsRoot(const int rootPos, const uint8_t *const dicRoot, 35 const int prevWordNodePos, DicNode *newRootNode) { 36 int curPos = rootPos; 37 const int pos = curPos; 38 const int childrenCount = BinaryFormat::getGroupCountAndForwardPointer(dicRoot, &curPos); 39 const int childrenPos = curPos; 40 newRootNode->initAsRoot(pos, childrenPos, childrenCount, prevWordNodePos); 41 } 42 43 /*static */ void DicNodeUtils::initAsRootWithPreviousWord(const int rootPos, 44 const uint8_t *const dicRoot, DicNode *prevWordLastNode, DicNode *newRootNode) { 45 int curPos = rootPos; 46 const int pos = curPos; 47 const int childrenCount = BinaryFormat::getGroupCountAndForwardPointer(dicRoot, &curPos); 48 const int childrenPos = curPos; 49 newRootNode->initAsRootWithPreviousWord(prevWordLastNode, pos, childrenPos, childrenCount); 50 } 51 52 /* static */ void DicNodeUtils::initByCopy(DicNode *srcNode, DicNode *destNode) { 53 destNode->initByCopy(srcNode); 54 } 55 56 /////////////////////////////////// 57 // Traverse node expansion utils // 58 /////////////////////////////////// 59 60 /* static */ void DicNodeUtils::createAndGetPassingChildNode(DicNode *dicNode, 61 const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly, 62 DicNodeVector *childDicNodes) { 63 // Passing multiple chars node. No need to traverse child 64 const int codePoint = dicNode->getNodeTypedCodePoint(); 65 const int baseLowerCaseCodePoint = toBaseLowerCase(codePoint); 66 const bool isMatch = isMatchedNodeCodePoint(pInfoState, pointIndex, exactOnly, codePoint); 67 if (isMatch || isIntentionalOmissionCodePoint(baseLowerCaseCodePoint)) { 68 childDicNodes->pushPassingChild(dicNode); 69 } 70 } 71 72 /* static */ int DicNodeUtils::createAndGetLeavingChildNode(DicNode *dicNode, int pos, 73 const uint8_t *const dicRoot, const int terminalDepth, const ProximityInfoState *pInfoState, 74 const int pointIndex, const bool exactOnly, const std::vector<int> *const codePointsFilter, 75 const ProximityInfo *const pInfo, DicNodeVector *childDicNodes) { 76 int nextPos = pos; 77 const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(dicRoot, &pos); 78 const bool hasMultipleChars = (0 != (BinaryFormat::FLAG_HAS_MULTIPLE_CHARS & flags)); 79 const bool isTerminal = (0 != (BinaryFormat::FLAG_IS_TERMINAL & flags)); 80 const bool hasChildren = BinaryFormat::hasChildrenInFlags(flags); 81 82 int codePoint = BinaryFormat::getCodePointAndForwardPointer(dicRoot, &pos); 83 ASSERT(NOT_A_CODE_POINT != codePoint); 84 const int nodeCodePoint = codePoint; 85 // TODO: optimize this 86 int additionalWordBuf[MAX_WORD_LENGTH]; 87 uint16_t additionalSubwordLength = 0; 88 additionalWordBuf[additionalSubwordLength++] = codePoint; 89 90 do { 91 const int nextCodePoint = hasMultipleChars 92 ? BinaryFormat::getCodePointAndForwardPointer(dicRoot, &pos) : NOT_A_CODE_POINT; 93 const bool isLastChar = (NOT_A_CODE_POINT == nextCodePoint); 94 if (!isLastChar) { 95 additionalWordBuf[additionalSubwordLength++] = nextCodePoint; 96 } 97 codePoint = nextCodePoint; 98 } while (NOT_A_CODE_POINT != codePoint); 99 100 const int probability = 101 isTerminal ? BinaryFormat::readProbabilityWithoutMovingPointer(dicRoot, pos) : -1; 102 pos = BinaryFormat::skipProbability(flags, pos); 103 int childrenPos = hasChildren ? BinaryFormat::readChildrenPosition(dicRoot, flags, pos) : 0; 104 const int attributesPos = BinaryFormat::skipChildrenPosition(flags, pos); 105 const int siblingPos = BinaryFormat::skipChildrenPosAndAttributes(dicRoot, flags, pos); 106 107 if (isDicNodeFilteredOut(nodeCodePoint, pInfo, codePointsFilter)) { 108 return siblingPos; 109 } 110 if (!isMatchedNodeCodePoint(pInfoState, pointIndex, exactOnly, nodeCodePoint)) { 111 return siblingPos; 112 } 113 const int childrenCount = hasChildren 114 ? BinaryFormat::getGroupCountAndForwardPointer(dicRoot, &childrenPos) : 0; 115 childDicNodes->pushLeavingChild(dicNode, nextPos, flags, childrenPos, attributesPos, siblingPos, 116 nodeCodePoint, childrenCount, probability, -1 /* bigramProbability */, isTerminal, 117 hasMultipleChars, hasChildren, additionalSubwordLength, additionalWordBuf); 118 return siblingPos; 119 } 120 121 /* static */ bool DicNodeUtils::isDicNodeFilteredOut(const int nodeCodePoint, 122 const ProximityInfo *const pInfo, const std::vector<int> *const codePointsFilter) { 123 const int filterSize = codePointsFilter ? codePointsFilter->size() : 0; 124 if (filterSize <= 0) { 125 return false; 126 } 127 if (pInfo && (pInfo->getKeyIndexOf(nodeCodePoint) == NOT_AN_INDEX 128 || isIntentionalOmissionCodePoint(nodeCodePoint))) { 129 // If normalized nodeCodePoint is not on the keyboard or skippable, this child is never 130 // filtered. 131 return false; 132 } 133 const int lowerCodePoint = toLowerCase(nodeCodePoint); 134 const int baseLowerCodePoint = toBaseCodePoint(lowerCodePoint); 135 // TODO: Avoid linear search 136 for (int i = 0; i < filterSize; ++i) { 137 // Checking if a normalized code point is in filter characters when pInfo is not 138 // null. When pInfo is null, nodeCodePoint is used to check filtering without 139 // normalizing. 140 if ((pInfo && ((*codePointsFilter)[i] == lowerCodePoint 141 || (*codePointsFilter)[i] == baseLowerCodePoint)) 142 || (!pInfo && (*codePointsFilter)[i] == nodeCodePoint)) { 143 return false; 144 } 145 } 146 return true; 147 } 148 149 /* static */ void DicNodeUtils::createAndGetAllLeavingChildNodes(DicNode *dicNode, 150 const uint8_t *const dicRoot, const ProximityInfoState *pInfoState, const int pointIndex, 151 const bool exactOnly, const std::vector<int> *const codePointsFilter, 152 const ProximityInfo *const pInfo, DicNodeVector *childDicNodes) { 153 const int terminalDepth = dicNode->getLeavingDepth(); 154 const int childCount = dicNode->getChildrenCount(); 155 int nextPos = dicNode->getChildrenPos(); 156 for (int i = 0; i < childCount; i++) { 157 const int filterSize = codePointsFilter ? codePointsFilter->size() : 0; 158 nextPos = createAndGetLeavingChildNode(dicNode, nextPos, dicRoot, terminalDepth, pInfoState, 159 pointIndex, exactOnly, codePointsFilter, pInfo, childDicNodes); 160 if (!pInfo && filterSize > 0 && childDicNodes->exceeds(filterSize)) { 161 // All code points have been found. 162 break; 163 } 164 } 165 } 166 167 /* static */ void DicNodeUtils::getAllChildDicNodes(DicNode *dicNode, const uint8_t *const dicRoot, 168 DicNodeVector *childDicNodes) { 169 getProximityChildDicNodes(dicNode, dicRoot, 0, 0, false, childDicNodes); 170 } 171 172 /* static */ void DicNodeUtils::getProximityChildDicNodes(DicNode *dicNode, 173 const uint8_t *const dicRoot, const ProximityInfoState *pInfoState, const int pointIndex, 174 bool exactOnly, DicNodeVector *childDicNodes) { 175 if (dicNode->isTotalInputSizeExceedingLimit()) { 176 return; 177 } 178 if (!dicNode->isLeavingNode()) { 179 DicNodeUtils::createAndGetPassingChildNode(dicNode, pInfoState, pointIndex, exactOnly, 180 childDicNodes); 181 } else { 182 DicNodeUtils::createAndGetAllLeavingChildNodes(dicNode, dicRoot, pInfoState, pointIndex, 183 exactOnly, 0 /* codePointsFilter */, 0 /* pInfo */, 184 childDicNodes); 185 } 186 } 187 188 /////////////////// 189 // Scoring utils // 190 /////////////////// 191 /** 192 * Computes the combined bigram / unigram cost for the given dicNode. 193 */ 194 /* static */ float DicNodeUtils::getBigramNodeImprobability(const uint8_t *const dicRoot, 195 const DicNode *const node, MultiBigramMap *multiBigramMap) { 196 if (node->isImpossibleBigramWord()) { 197 return static_cast<float>(MAX_VALUE_FOR_WEIGHTING); 198 } 199 const int probability = getBigramNodeProbability(dicRoot, node, multiBigramMap); 200 // TODO: This equation to calculate the improbability looks unreasonable. Investigate this. 201 const float cost = static_cast<float>(MAX_PROBABILITY - probability) 202 / static_cast<float>(MAX_PROBABILITY); 203 return cost; 204 } 205 206 /* static */ int DicNodeUtils::getBigramNodeProbability(const uint8_t *const dicRoot, 207 const DicNode *const node, MultiBigramMap *multiBigramMap) { 208 const int unigramProbability = node->getProbability(); 209 const int wordPos = node->getPos(); 210 const int prevWordPos = node->getPrevWordPos(); 211 if (NOT_VALID_WORD == wordPos || NOT_VALID_WORD == prevWordPos) { 212 // Note: Normally wordPos comes from the dictionary and should never equal NOT_VALID_WORD. 213 return backoff(unigramProbability); 214 } 215 if (multiBigramMap) { 216 return multiBigramMap->getBigramProbability( 217 dicRoot, prevWordPos, wordPos, unigramProbability); 218 } 219 return BinaryFormat::getBigramProbability(dicRoot, prevWordPos, wordPos, unigramProbability); 220 } 221 222 /////////////////////////////////////// 223 // Bigram / Unigram dictionary utils // 224 /////////////////////////////////////// 225 226 /* static */ bool DicNodeUtils::isMatchedNodeCodePoint(const ProximityInfoState *pInfoState, 227 const int pointIndex, const bool exactOnly, const int nodeCodePoint) { 228 if (!pInfoState) { 229 return true; 230 } 231 if (exactOnly) { 232 return pInfoState->getPrimaryCodePointAt(pointIndex) == nodeCodePoint; 233 } 234 const ProximityType matchedId = pInfoState->getProximityType(pointIndex, nodeCodePoint, 235 true /* checkProximityChars */); 236 return isProximityChar(matchedId); 237 } 238 239 //////////////// 240 // Char utils // 241 //////////////// 242 243 // TODO: Move to char_utils? 244 /* static */ int DicNodeUtils::appendTwoWords(const int *const src0, const int16_t length0, 245 const int *const src1, const int16_t length1, int *dest) { 246 int actualLength0 = 0; 247 for (int i = 0; i < length0; ++i) { 248 if (src0[i] == 0) { 249 break; 250 } 251 actualLength0 = i + 1; 252 } 253 actualLength0 = min(actualLength0, MAX_WORD_LENGTH); 254 memcpy(dest, src0, actualLength0 * sizeof(dest[0])); 255 if (!src1 || length1 == 0) { 256 return actualLength0; 257 } 258 int actualLength1 = 0; 259 for (int i = 0; i < length1; ++i) { 260 if (src1[i] == 0) { 261 break; 262 } 263 actualLength1 = i + 1; 264 } 265 actualLength1 = min(actualLength1, MAX_WORD_LENGTH - actualLength0 - 1); 266 memcpy(&dest[actualLength0], src1, actualLength1 * sizeof(dest[0])); 267 return actualLength0 + actualLength1; 268 } 269 } // namespace latinime 270