1 /* 2 * Copyright (C) 2012 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef LATINIME_DIC_NODE_H 18 #define LATINIME_DIC_NODE_H 19 20 #include "defines.h" 21 #include "suggest/core/dicnode/dic_node_profiler.h" 22 #include "suggest/core/dicnode/dic_node_utils.h" 23 #include "suggest/core/dicnode/internal/dic_node_state.h" 24 #include "suggest/core/dicnode/internal/dic_node_properties.h" 25 #include "suggest/core/dictionary/digraph_utils.h" 26 #include "suggest/core/dictionary/error_type_utils.h" 27 #include "suggest/core/layout/proximity_info_state.h" 28 #include "utils/char_utils.h" 29 30 #if DEBUG_DICT 31 #define LOGI_SHOW_ADD_COST_PROP \ 32 do { \ 33 char charBuf[50]; \ 34 INTS_TO_CHARS(getOutputWordBuf(), getNodeCodePointCount(), charBuf, NELEMS(charBuf)); \ 35 AKLOGI("%20s, \"%c\", size = %03d, total = %03d, index(0) = %02d, dist = %.4f, %s,,", \ 36 __FUNCTION__, getNodeCodePoint(), inputSize, getTotalInputIndex(), \ 37 getInputIndex(0), getNormalizedCompoundDistance(), charBuf); \ 38 } while (0) 39 #define DUMP_WORD_AND_SCORE(header) \ 40 do { \ 41 char charBuf[50]; \ 42 INTS_TO_CHARS(getOutputWordBuf(), \ 43 getNodeCodePointCount() \ 44 + mDicNodeState.mDicNodeStateOutput.getPrevWordsLength(), \ 45 charBuf, NELEMS(charBuf)); \ 46 AKLOGI("#%8s, %5f, %5f, %5f, %5f, %s, %d, %5f,", header, \ 47 getSpatialDistanceForScoring(), \ 48 mDicNodeState.mDicNodeStateScoring.getLanguageDistance(), \ 49 getNormalizedCompoundDistance(), getRawLength(), charBuf, \ 50 getInputIndex(0), getNormalizedCompoundDistanceAfterFirstWord()); \ 51 } while (0) 52 #else 53 #define LOGI_SHOW_ADD_COST_PROP 54 #define DUMP_WORD_AND_SCORE(header) 55 #endif 56 57 namespace latinime { 58 59 // This struct is purely a bucket to return values. No instances of this struct should be kept. 60 struct DicNode_InputStateG { 61 DicNode_InputStateG() 62 : mNeedsToUpdateInputStateG(false), mPointerId(0), mInputIndex(0), 63 mPrevCodePoint(0), mTerminalDiffCost(0.0f), mRawLength(0.0f), 64 mDoubleLetterLevel(NOT_A_DOUBLE_LETTER) {} 65 66 bool mNeedsToUpdateInputStateG; 67 int mPointerId; 68 int16_t mInputIndex; 69 int mPrevCodePoint; 70 float mTerminalDiffCost; 71 float mRawLength; 72 DoubleLetterLevel mDoubleLetterLevel; 73 }; 74 75 class DicNode { 76 // Caveat: We define Weighting as a friend class of DicNode to let Weighting change 77 // the distance of DicNode. 78 // Caution!!! In general, we avoid using the "friend" access modifier. 79 // This is an exception to explicitly hide DicNode::addCost() from all classes but Weighting. 80 friend class Weighting; 81 82 public: 83 #if DEBUG_DICT 84 DicNodeProfiler mProfiler; 85 #endif 86 87 AK_FORCE_INLINE DicNode() 88 : 89 #if DEBUG_DICT 90 mProfiler(), 91 #endif 92 mDicNodeProperties(), mDicNodeState(), mIsCachedForNextSuggestion(false) {} 93 94 DicNode(const DicNode &dicNode); 95 DicNode &operator=(const DicNode &dicNode); 96 ~DicNode() {} 97 98 // Init for copy 99 void initByCopy(const DicNode *const dicNode) { 100 mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; 101 mDicNodeProperties.initByCopy(&dicNode->mDicNodeProperties); 102 mDicNodeState.initByCopy(&dicNode->mDicNodeState); 103 PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); 104 } 105 106 // Init for root with prevWordsPtNodePos which is used for n-gram 107 void initAsRoot(const int rootPtNodeArrayPos, const int *const prevWordsPtNodePos) { 108 mIsCachedForNextSuggestion = false; 109 mDicNodeProperties.init(rootPtNodeArrayPos, prevWordsPtNodePos); 110 mDicNodeState.init(); 111 PROF_NODE_RESET(mProfiler); 112 } 113 114 // Init for root with previous word 115 void initAsRootWithPreviousWord(const DicNode *const dicNode, const int rootPtNodeArrayPos) { 116 mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; 117 int newPrevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; 118 newPrevWordsPtNodePos[0] = dicNode->mDicNodeProperties.getPtNodePos(); 119 for (size_t i = 1; i < NELEMS(newPrevWordsPtNodePos); ++i) { 120 newPrevWordsPtNodePos[i] = dicNode->getPrevWordsTerminalPtNodePos()[i - 1]; 121 } 122 mDicNodeProperties.init(rootPtNodeArrayPos, newPrevWordsPtNodePos); 123 mDicNodeState.initAsRootWithPreviousWord(&dicNode->mDicNodeState, 124 dicNode->mDicNodeProperties.getDepth()); 125 PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); 126 } 127 128 void initAsPassingChild(const DicNode *parentDicNode) { 129 mIsCachedForNextSuggestion = parentDicNode->mIsCachedForNextSuggestion; 130 const int codePoint = 131 parentDicNode->mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt( 132 parentDicNode->getNodeCodePointCount()); 133 mDicNodeProperties.init(&parentDicNode->mDicNodeProperties, codePoint); 134 mDicNodeState.initByCopy(&parentDicNode->mDicNodeState); 135 PROF_NODE_COPY(&parentDicNode->mProfiler, mProfiler); 136 } 137 138 void initAsChild(const DicNode *const dicNode, const int ptNodePos, 139 const int childrenPtNodeArrayPos, const int probability, const bool isTerminal, 140 const bool hasChildren, const bool isBlacklistedOrNotAWord, 141 const uint16_t mergedNodeCodePointCount, const int *const mergedNodeCodePoints) { 142 uint16_t newDepth = static_cast<uint16_t>(dicNode->getNodeCodePointCount() + 1); 143 mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; 144 const uint16_t newLeavingDepth = static_cast<uint16_t>( 145 dicNode->mDicNodeProperties.getLeavingDepth() + mergedNodeCodePointCount); 146 mDicNodeProperties.init(ptNodePos, childrenPtNodeArrayPos, mergedNodeCodePoints[0], 147 probability, isTerminal, hasChildren, isBlacklistedOrNotAWord, newDepth, 148 newLeavingDepth, dicNode->mDicNodeProperties.getPrevWordsTerminalPtNodePos()); 149 mDicNodeState.init(&dicNode->mDicNodeState, mergedNodeCodePointCount, 150 mergedNodeCodePoints); 151 PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); 152 } 153 154 bool isRoot() const { 155 return getNodeCodePointCount() == 0; 156 } 157 158 bool hasChildren() const { 159 return mDicNodeProperties.hasChildren(); 160 } 161 162 bool isLeavingNode() const { 163 ASSERT(getNodeCodePointCount() <= mDicNodeProperties.getLeavingDepth()); 164 return getNodeCodePointCount() == mDicNodeProperties.getLeavingDepth(); 165 } 166 167 AK_FORCE_INLINE bool isFirstLetter() const { 168 return getNodeCodePointCount() == 1; 169 } 170 171 bool isCached() const { 172 return mIsCachedForNextSuggestion; 173 } 174 175 void setCached() { 176 mIsCachedForNextSuggestion = true; 177 } 178 179 // Check if the current word and the previous word can be considered as a valid multiple word 180 // suggestion. 181 bool isValidMultipleWordSuggestion() const { 182 if (isBlacklistedOrNotAWord()) { 183 return false; 184 } 185 // Treat suggestion as invalid if the current and the previous word are single character 186 // words. 187 const int prevWordLen = mDicNodeState.mDicNodeStateOutput.getPrevWordsLength() 188 - mDicNodeState.mDicNodeStateOutput.getPrevWordStart() - 1; 189 const int currentWordLen = getNodeCodePointCount(); 190 return (prevWordLen != 1 || currentWordLen != 1); 191 } 192 193 bool isFirstCharUppercase() const { 194 const int c = mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt(0); 195 return CharUtils::isAsciiUpper(c); 196 } 197 198 bool isCompletion(const int inputSize) const { 199 return mDicNodeState.mDicNodeStateInput.getInputIndex(0) >= inputSize; 200 } 201 202 bool canDoLookAheadCorrection(const int inputSize) const { 203 return mDicNodeState.mDicNodeStateInput.getInputIndex(0) < inputSize - 1; 204 } 205 206 // Used to get n-gram probability in DicNodeUtils. 207 int getPtNodePos() const { 208 return mDicNodeProperties.getPtNodePos(); 209 } 210 211 // TODO: Use view class to return PtNodePos array. 212 const int *getPrevWordsTerminalPtNodePos() const { 213 return mDicNodeProperties.getPrevWordsTerminalPtNodePos(); 214 } 215 216 // Used in DicNodeUtils 217 int getChildrenPtNodeArrayPos() const { 218 return mDicNodeProperties.getChildrenPtNodeArrayPos(); 219 } 220 221 int getProbability() const { 222 return mDicNodeProperties.getProbability(); 223 } 224 225 AK_FORCE_INLINE bool isTerminalDicNode() const { 226 const bool isTerminalPtNode = mDicNodeProperties.isTerminal(); 227 const int currentDicNodeDepth = getNodeCodePointCount(); 228 const int terminalDicNodeDepth = mDicNodeProperties.getLeavingDepth(); 229 return isTerminalPtNode && currentDicNodeDepth > 0 230 && currentDicNodeDepth == terminalDicNodeDepth; 231 } 232 233 bool shouldBeFilteredBySafetyNetForBigram() const { 234 const uint16_t currentDepth = getNodeCodePointCount(); 235 const int prevWordLen = mDicNodeState.mDicNodeStateOutput.getPrevWordsLength() 236 - mDicNodeState.mDicNodeStateOutput.getPrevWordStart() - 1; 237 return !(currentDepth > 0 && (currentDepth != 1 || prevWordLen != 1)); 238 } 239 240 bool hasMatchedOrProximityCodePoints() const { 241 // This DicNode does not have matched or proximity code points when all code points have 242 // been handled as edit corrections or completion so far. 243 const int editCorrectionCount = mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount(); 244 const int completionCount = mDicNodeState.mDicNodeStateScoring.getCompletionCount(); 245 return (editCorrectionCount + completionCount) < getNodeCodePointCount(); 246 } 247 248 bool isTotalInputSizeExceedingLimit() const { 249 // TODO: 3 can be 2? Needs to be investigated. 250 // TODO: Have a const variable for 3 (or 2) 251 return getTotalNodeCodePointCount() > MAX_WORD_LENGTH - 3; 252 } 253 254 void outputResult(int *dest) const { 255 memmove(dest, getOutputWordBuf(), getTotalNodeCodePointCount() * sizeof(dest[0])); 256 DUMP_WORD_AND_SCORE("OUTPUT"); 257 } 258 259 // "Total" in this context (and other methods in this class) means the whole suggestion. When 260 // this represents a multi-word suggestion, the referenced PtNode (in mDicNodeState) is only 261 // the one that corresponds to the last word of the suggestion, and all the previous words 262 // are concatenated together in mDicNodeStateOutput. 263 int getTotalNodeSpaceCount() const { 264 if (!hasMultipleWords()) { 265 return 0; 266 } 267 return CharUtils::getSpaceCount(mDicNodeState.mDicNodeStateOutput.getCodePointBuf(), 268 mDicNodeState.mDicNodeStateOutput.getPrevWordsLength()); 269 } 270 271 int getSecondWordFirstInputIndex(const ProximityInfoState *const pInfoState) const { 272 const int inputIndex = mDicNodeState.mDicNodeStateOutput.getSecondWordFirstInputIndex(); 273 if (inputIndex == NOT_AN_INDEX) { 274 return NOT_AN_INDEX; 275 } else { 276 return pInfoState->getInputIndexOfSampledPoint(inputIndex); 277 } 278 } 279 280 bool hasMultipleWords() const { 281 return mDicNodeState.mDicNodeStateOutput.getPrevWordCount() > 0; 282 } 283 284 int getProximityCorrectionCount() const { 285 return mDicNodeState.mDicNodeStateScoring.getProximityCorrectionCount(); 286 } 287 288 int getEditCorrectionCount() const { 289 return mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount(); 290 } 291 292 // Used to prune nodes 293 float getNormalizedCompoundDistance() const { 294 return mDicNodeState.mDicNodeStateScoring.getNormalizedCompoundDistance(); 295 } 296 297 // Used to prune nodes 298 float getNormalizedSpatialDistance() const { 299 return mDicNodeState.mDicNodeStateScoring.getSpatialDistance() 300 / static_cast<float>(getInputIndex(0) + 1); 301 } 302 303 // Used to prune nodes 304 float getCompoundDistance() const { 305 return mDicNodeState.mDicNodeStateScoring.getCompoundDistance(); 306 } 307 308 // Used to prune nodes 309 float getCompoundDistance(const float languageWeight) const { 310 return mDicNodeState.mDicNodeStateScoring.getCompoundDistance(languageWeight); 311 } 312 313 AK_FORCE_INLINE const int *getOutputWordBuf() const { 314 return mDicNodeState.mDicNodeStateOutput.getCodePointBuf(); 315 } 316 317 int getPrevCodePointG(int pointerId) const { 318 return mDicNodeState.mDicNodeStateInput.getPrevCodePoint(pointerId); 319 } 320 321 // Whether the current codepoint can be an intentional omission, in which case the traversal 322 // algorithm will always check for a possible omission here. 323 bool canBeIntentionalOmission() const { 324 return CharUtils::isIntentionalOmissionCodePoint(getNodeCodePoint()); 325 } 326 327 // Whether the omission is so frequent that it should incur zero cost. 328 bool isZeroCostOmission() const { 329 // TODO: do not hardcode and read from header 330 return (getNodeCodePoint() == KEYCODE_SINGLE_QUOTE); 331 } 332 333 // TODO: remove 334 float getTerminalDiffCostG(int path) const { 335 return mDicNodeState.mDicNodeStateInput.getTerminalDiffCost(path); 336 } 337 338 ////////////////////// 339 // Temporary getter // 340 // TODO: Remove // 341 ////////////////////// 342 // TODO: Remove once touch path is merged into ProximityInfoState 343 // Note: Returned codepoint may be a digraph codepoint if the node is in a composite glyph. 344 int getNodeCodePoint() const { 345 const int codePoint = mDicNodeProperties.getDicNodeCodePoint(); 346 const DigraphUtils::DigraphCodePointIndex digraphIndex = 347 mDicNodeState.mDicNodeStateScoring.getDigraphIndex(); 348 if (digraphIndex == DigraphUtils::NOT_A_DIGRAPH_INDEX) { 349 return codePoint; 350 } 351 return DigraphUtils::getDigraphCodePointForIndex(codePoint, digraphIndex); 352 } 353 354 //////////////////////////////// 355 // Utils for cost calculation // 356 //////////////////////////////// 357 AK_FORCE_INLINE bool isSameNodeCodePoint(const DicNode *const dicNode) const { 358 return mDicNodeProperties.getDicNodeCodePoint() 359 == dicNode->mDicNodeProperties.getDicNodeCodePoint(); 360 } 361 362 // TODO: remove 363 // TODO: rename getNextInputIndex 364 int16_t getInputIndex(int pointerId) const { 365 return mDicNodeState.mDicNodeStateInput.getInputIndex(pointerId); 366 } 367 368 //////////////////////////////////// 369 // Getter of features for scoring // 370 //////////////////////////////////// 371 float getSpatialDistanceForScoring() const { 372 return mDicNodeState.mDicNodeStateScoring.getSpatialDistance(); 373 } 374 375 // For space-aware gestures, we store the normalized distance at the char index 376 // that ends the first word of the suggestion. We call this the distance after 377 // first word. 378 float getNormalizedCompoundDistanceAfterFirstWord() const { 379 return mDicNodeState.mDicNodeStateScoring.getNormalizedCompoundDistanceAfterFirstWord(); 380 } 381 382 float getRawLength() const { 383 return mDicNodeState.mDicNodeStateScoring.getRawLength(); 384 } 385 386 DoubleLetterLevel getDoubleLetterLevel() const { 387 return mDicNodeState.mDicNodeStateScoring.getDoubleLetterLevel(); 388 } 389 390 void setDoubleLetterLevel(DoubleLetterLevel doubleLetterLevel) { 391 mDicNodeState.mDicNodeStateScoring.setDoubleLetterLevel(doubleLetterLevel); 392 } 393 394 bool isInDigraph() const { 395 return mDicNodeState.mDicNodeStateScoring.getDigraphIndex() 396 != DigraphUtils::NOT_A_DIGRAPH_INDEX; 397 } 398 399 void advanceDigraphIndex() { 400 mDicNodeState.mDicNodeStateScoring.advanceDigraphIndex(); 401 } 402 403 ErrorTypeUtils::ErrorType getContainedErrorTypes() const { 404 return mDicNodeState.mDicNodeStateScoring.getContainedErrorTypes(); 405 } 406 407 bool isBlacklistedOrNotAWord() const { 408 return mDicNodeProperties.isBlacklistedOrNotAWord(); 409 } 410 411 inline uint16_t getNodeCodePointCount() const { 412 return mDicNodeProperties.getDepth(); 413 } 414 415 // Returns code point count including spaces 416 inline uint16_t getTotalNodeCodePointCount() const { 417 return getNodeCodePointCount() + mDicNodeState.mDicNodeStateOutput.getPrevWordsLength(); 418 } 419 420 AK_FORCE_INLINE void dump(const char *tag) const { 421 #if DEBUG_DICT 422 DUMP_WORD_AND_SCORE(tag); 423 #if DEBUG_DUMP_ERROR 424 mProfiler.dump(); 425 #endif 426 #endif 427 } 428 429 AK_FORCE_INLINE bool compare(const DicNode *right) const { 430 // Promote exact matches to prevent them from being pruned. 431 const bool leftExactMatch = ErrorTypeUtils::isExactMatch(getContainedErrorTypes()); 432 const bool rightExactMatch = ErrorTypeUtils::isExactMatch(right->getContainedErrorTypes()); 433 if (leftExactMatch != rightExactMatch) { 434 return leftExactMatch; 435 } 436 const float diff = 437 right->getNormalizedCompoundDistance() - getNormalizedCompoundDistance(); 438 static const float MIN_DIFF = 0.000001f; 439 if (diff > MIN_DIFF) { 440 return true; 441 } else if (diff < -MIN_DIFF) { 442 return false; 443 } 444 const int depth = getNodeCodePointCount(); 445 const int depthDiff = right->getNodeCodePointCount() - depth; 446 if (depthDiff != 0) { 447 return depthDiff > 0; 448 } 449 for (int i = 0; i < depth; ++i) { 450 const int codePoint = mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt(i); 451 const int rightCodePoint = 452 right->mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt(i); 453 if (codePoint != rightCodePoint) { 454 return rightCodePoint > codePoint; 455 } 456 } 457 // Compare pointer values here for stable comparison 458 return this > right; 459 } 460 461 private: 462 DicNodeProperties mDicNodeProperties; 463 DicNodeState mDicNodeState; 464 // TODO: Remove 465 bool mIsCachedForNextSuggestion; 466 467 AK_FORCE_INLINE int getTotalInputIndex() const { 468 int index = 0; 469 for (int i = 0; i < MAX_POINTER_COUNT_G; i++) { 470 index += mDicNodeState.mDicNodeStateInput.getInputIndex(i); 471 } 472 return index; 473 } 474 475 // Caveat: Must not be called outside Weighting 476 // This restriction is guaranteed by "friend" 477 AK_FORCE_INLINE void addCost(const float spatialCost, const float languageCost, 478 const bool doNormalization, const int inputSize, 479 const ErrorTypeUtils::ErrorType errorType) { 480 if (DEBUG_GEO_FULL) { 481 LOGI_SHOW_ADD_COST_PROP; 482 } 483 mDicNodeState.mDicNodeStateScoring.addCost(spatialCost, languageCost, doNormalization, 484 inputSize, getTotalInputIndex(), errorType); 485 } 486 487 // Saves the current normalized compound distance for space-aware gestures. 488 // See getNormalizedCompoundDistanceAfterFirstWord for details. 489 AK_FORCE_INLINE void saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet() { 490 mDicNodeState.mDicNodeStateScoring.saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet(); 491 } 492 493 // Caveat: Must not be called outside Weighting 494 // This restriction is guaranteed by "friend" 495 AK_FORCE_INLINE void forwardInputIndex(const int pointerId, const int count, 496 const bool overwritesPrevCodePointByNodeCodePoint) { 497 if (count == 0) { 498 return; 499 } 500 mDicNodeState.mDicNodeStateInput.forwardInputIndex(pointerId, count); 501 if (overwritesPrevCodePointByNodeCodePoint) { 502 mDicNodeState.mDicNodeStateInput.setPrevCodePoint(0, getNodeCodePoint()); 503 } 504 } 505 506 AK_FORCE_INLINE void updateInputIndexG(const DicNode_InputStateG *const inputStateG) { 507 if (mDicNodeState.mDicNodeStateOutput.getPrevWordCount() == 1 && isFirstLetter()) { 508 mDicNodeState.mDicNodeStateOutput.setSecondWordFirstInputIndex( 509 inputStateG->mInputIndex); 510 } 511 mDicNodeState.mDicNodeStateInput.updateInputIndexG(inputStateG->mPointerId, 512 inputStateG->mInputIndex, inputStateG->mPrevCodePoint, 513 inputStateG->mTerminalDiffCost, inputStateG->mRawLength); 514 mDicNodeState.mDicNodeStateScoring.addRawLength(inputStateG->mRawLength); 515 mDicNodeState.mDicNodeStateScoring.setDoubleLetterLevel(inputStateG->mDoubleLetterLevel); 516 } 517 }; 518 } // namespace latinime 519 #endif // LATINIME_DIC_NODE_H 520