1 /* 2 * Copyright (C) 2012 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef LATINIME_DIC_NODE_H 18 #define LATINIME_DIC_NODE_H 19 20 #include "char_utils.h" 21 #include "defines.h" 22 #include "dic_node_state.h" 23 #include "dic_node_profiler.h" 24 #include "dic_node_properties.h" 25 #include "dic_node_release_listener.h" 26 #include "digraph_utils.h" 27 28 #if DEBUG_DICT 29 #define LOGI_SHOW_ADD_COST_PROP \ 30 do { char charBuf[50]; \ 31 INTS_TO_CHARS(getOutputWordBuf(), getDepth(), charBuf); \ 32 AKLOGI("%20s, \"%c\", size = %03d, total = %03d, index(0) = %02d, dist = %.4f, %s,,", \ 33 __FUNCTION__, getNodeCodePoint(), inputSize, getTotalInputIndex(), \ 34 getInputIndex(0), getNormalizedCompoundDistance(), charBuf); } while (0) 35 #define DUMP_WORD_AND_SCORE(header) \ 36 do { char charBuf[50]; char prevWordCharBuf[50]; \ 37 INTS_TO_CHARS(getOutputWordBuf(), getDepth(), charBuf); \ 38 INTS_TO_CHARS(mDicNodeState.mDicNodeStatePrevWord.mPrevWord, \ 39 mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(), prevWordCharBuf); \ 40 AKLOGI("#%8s, %5f, %5f, %5f, %5f, %s, %s, %d,,", header, \ 41 getSpatialDistanceForScoring(), getLanguageDistanceForScoring(), \ 42 getNormalizedCompoundDistance(), getRawLength(), prevWordCharBuf, charBuf, \ 43 getInputIndex(0)); \ 44 } while (0) 45 #else 46 #define LOGI_SHOW_ADD_COST_PROP 47 #define DUMP_WORD_AND_SCORE(header) 48 #endif 49 50 namespace latinime { 51 52 // This struct is purely a bucket to return values. No instances of this struct should be kept. 53 struct DicNode_InputStateG { 54 bool mNeedsToUpdateInputStateG; 55 int mPointerId; 56 int16_t mInputIndex; 57 int mPrevCodePoint; 58 float mTerminalDiffCost; 59 float mRawLength; 60 DoubleLetterLevel mDoubleLetterLevel; 61 }; 62 63 class DicNode { 64 // Caveat: We define Weighting as a friend class of DicNode to let Weighting change 65 // the distance of DicNode. 66 // Caution!!! In general, we avoid using the "friend" access modifier. 67 // This is an exception to explicitly hide DicNode::addCost() from all classes but Weighting. 68 friend class Weighting; 69 70 public: 71 #if DEBUG_DICT 72 DicNodeProfiler mProfiler; 73 #endif 74 ////////////////// 75 // Memory utils // 76 ////////////////// 77 AK_FORCE_INLINE static void managedDelete(DicNode *node) { 78 node->remove(); 79 } 80 // end 81 ///////////////// 82 83 AK_FORCE_INLINE DicNode() 84 : 85 #if DEBUG_DICT 86 mProfiler(), 87 #endif 88 mDicNodeProperties(), mDicNodeState(), mIsCachedForNextSuggestion(false), 89 mIsUsed(false), mReleaseListener(0) {} 90 91 DicNode(const DicNode &dicNode); 92 DicNode &operator=(const DicNode &dicNode); 93 virtual ~DicNode() {} 94 95 // TODO: minimize arguments by looking binary_format 96 // Init for copy 97 void initByCopy(const DicNode *dicNode) { 98 mIsUsed = true; 99 mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; 100 mDicNodeProperties.init(&dicNode->mDicNodeProperties); 101 mDicNodeState.init(&dicNode->mDicNodeState); 102 PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); 103 } 104 105 // TODO: minimize arguments by looking binary_format 106 // Init for root with prevWordNodePos which is used for bigram 107 void initAsRoot(const int pos, const int childrenPos, const int childrenCount, 108 const int prevWordNodePos) { 109 mIsUsed = true; 110 mIsCachedForNextSuggestion = false; 111 mDicNodeProperties.init( 112 pos, 0, childrenPos, 0, 0, 0, childrenCount, 0, 0, false, false, true, 0, 0); 113 mDicNodeState.init(prevWordNodePos); 114 PROF_NODE_RESET(mProfiler); 115 } 116 117 void initAsPassingChild(DicNode *parentNode) { 118 mIsUsed = true; 119 mIsCachedForNextSuggestion = parentNode->mIsCachedForNextSuggestion; 120 const int c = parentNode->getNodeTypedCodePoint(); 121 mDicNodeProperties.init(&parentNode->mDicNodeProperties, c); 122 mDicNodeState.init(&parentNode->mDicNodeState); 123 PROF_NODE_COPY(&parentNode->mProfiler, mProfiler); 124 } 125 126 // TODO: minimize arguments by looking binary_format 127 // Init for root with previous word 128 void initAsRootWithPreviousWord(DicNode *dicNode, const int pos, const int childrenPos, 129 const int childrenCount) { 130 mIsUsed = true; 131 mIsCachedForNextSuggestion = false; 132 mDicNodeProperties.init( 133 pos, 0, childrenPos, 0, 0, 0, childrenCount, 0, 0, false, false, true, 0, 0); 134 // TODO: Move to dicNodeState? 135 mDicNodeState.mDicNodeStateOutput.init(); // reset for next word 136 mDicNodeState.mDicNodeStateInput.init( 137 &dicNode->mDicNodeState.mDicNodeStateInput, true /* resetTerminalDiffCost */); 138 mDicNodeState.mDicNodeStateScoring.init( 139 &dicNode->mDicNodeState.mDicNodeStateScoring); 140 mDicNodeState.mDicNodeStatePrevWord.init( 141 dicNode->mDicNodeState.mDicNodeStatePrevWord.getPrevWordCount() + 1, 142 dicNode->mDicNodeProperties.getProbability(), 143 dicNode->mDicNodeProperties.getPos(), 144 dicNode->mDicNodeState.mDicNodeStatePrevWord.mPrevWord, 145 dicNode->mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(), 146 dicNode->getOutputWordBuf(), 147 dicNode->mDicNodeProperties.getDepth(), 148 dicNode->mDicNodeState.mDicNodeStatePrevWord.mPrevSpacePositions, 149 mDicNodeState.mDicNodeStateInput.getInputIndex(0) /* lastInputIndex */); 150 PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); 151 } 152 153 // TODO: minimize arguments by looking binary_format 154 void initAsChild(DicNode *dicNode, const int pos, const uint8_t flags, const int childrenPos, 155 const int attributesPos, const int siblingPos, const int nodeCodePoint, 156 const int childrenCount, const int probability, const int bigramProbability, 157 const bool isTerminal, const bool hasMultipleChars, const bool hasChildren, 158 const uint16_t additionalSubwordLength, const int *additionalSubword) { 159 mIsUsed = true; 160 uint16_t newDepth = static_cast<uint16_t>(dicNode->getDepth() + 1); 161 mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; 162 const uint16_t newLeavingDepth = static_cast<uint16_t>( 163 dicNode->mDicNodeProperties.getLeavingDepth() + additionalSubwordLength); 164 mDicNodeProperties.init(pos, flags, childrenPos, attributesPos, siblingPos, nodeCodePoint, 165 childrenCount, probability, bigramProbability, isTerminal, hasMultipleChars, 166 hasChildren, newDepth, newLeavingDepth); 167 mDicNodeState.init(&dicNode->mDicNodeState, additionalSubwordLength, additionalSubword); 168 PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); 169 } 170 171 AK_FORCE_INLINE void remove() { 172 mIsUsed = false; 173 if (mReleaseListener) { 174 mReleaseListener->onReleased(this); 175 } 176 } 177 178 bool isUsed() const { 179 return mIsUsed; 180 } 181 182 bool isRoot() const { 183 return getDepth() == 0; 184 } 185 186 bool hasChildren() const { 187 return mDicNodeProperties.hasChildren(); 188 } 189 190 bool isLeavingNode() const { 191 ASSERT(getDepth() <= getLeavingDepth()); 192 return getDepth() == getLeavingDepth(); 193 } 194 195 AK_FORCE_INLINE bool isFirstLetter() const { 196 return getDepth() == 1; 197 } 198 199 bool isCached() const { 200 return mIsCachedForNextSuggestion; 201 } 202 203 void setCached() { 204 mIsCachedForNextSuggestion = true; 205 } 206 207 // Used to expand the node in DicNodeUtils 208 int getNodeTypedCodePoint() const { 209 return mDicNodeState.mDicNodeStateOutput.getCodePointAt(getDepth()); 210 } 211 212 bool isImpossibleBigramWord() const { 213 if (mDicNodeProperties.hasBlacklistedOrNotAWordFlag()) { 214 return true; 215 } 216 const int prevWordLen = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength() 217 - mDicNodeState.mDicNodeStatePrevWord.getPrevWordStart() - 1; 218 const int currentWordLen = getDepth(); 219 return (prevWordLen == 1 && currentWordLen == 1); 220 } 221 222 bool isFirstCharUppercase() const { 223 const int c = getOutputWordBuf()[0]; 224 return isAsciiUpper(c); 225 } 226 227 bool isFirstWord() const { 228 return mDicNodeState.mDicNodeStatePrevWord.getPrevWordNodePos() == NOT_VALID_WORD; 229 } 230 231 bool isCompletion(const int inputSize) const { 232 return mDicNodeState.mDicNodeStateInput.getInputIndex(0) >= inputSize; 233 } 234 235 bool canDoLookAheadCorrection(const int inputSize) const { 236 return mDicNodeState.mDicNodeStateInput.getInputIndex(0) < inputSize - 1; 237 } 238 239 // Used to get bigram probability in DicNodeUtils 240 int getPos() const { 241 return mDicNodeProperties.getPos(); 242 } 243 244 // Used to get bigram probability in DicNodeUtils 245 int getPrevWordPos() const { 246 return mDicNodeState.mDicNodeStatePrevWord.getPrevWordNodePos(); 247 } 248 249 // Used in DicNodeUtils 250 int getChildrenPos() const { 251 return mDicNodeProperties.getChildrenPos(); 252 } 253 254 // Used in DicNodeUtils 255 int getChildrenCount() const { 256 return mDicNodeProperties.getChildrenCount(); 257 } 258 259 // Used in DicNodeUtils 260 int getProbability() const { 261 return mDicNodeProperties.getProbability(); 262 } 263 264 AK_FORCE_INLINE bool isTerminalWordNode() const { 265 const bool isTerminalNodes = mDicNodeProperties.isTerminal(); 266 const int currentNodeDepth = getDepth(); 267 const int terminalNodeDepth = mDicNodeProperties.getLeavingDepth(); 268 return isTerminalNodes && currentNodeDepth > 0 && currentNodeDepth == terminalNodeDepth; 269 } 270 271 bool shouldBeFilterdBySafetyNetForBigram() const { 272 const uint16_t currentDepth = getDepth(); 273 const int prevWordLen = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength() 274 - mDicNodeState.mDicNodeStatePrevWord.getPrevWordStart() - 1; 275 return !(currentDepth > 0 && (currentDepth != 1 || prevWordLen != 1)); 276 } 277 278 uint16_t getLeavingDepth() const { 279 return mDicNodeProperties.getLeavingDepth(); 280 } 281 282 bool isTotalInputSizeExceedingLimit() const { 283 const int prevWordsLen = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(); 284 const int currentWordDepth = getDepth(); 285 // TODO: 3 can be 2? Needs to be investigated. 286 // TODO: Have a const variable for 3 (or 2) 287 return prevWordsLen + currentWordDepth > MAX_WORD_LENGTH - 3; 288 } 289 290 // TODO: This may be defective. Needs to be revised. 291 bool truncateNode(const DicNode *const topNode, const int inputCommitPoint) { 292 const int prevWordLenOfTop = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(); 293 int newPrevWordStartIndex = inputCommitPoint; 294 int charCount = 0; 295 // Find new word start index 296 for (int i = 0; i < prevWordLenOfTop; ++i) { 297 const int c = mDicNodeState.mDicNodeStatePrevWord.getPrevWordCodePointAt(i); 298 // TODO: Check other separators. 299 if (c != KEYCODE_SPACE && c != KEYCODE_SINGLE_QUOTE) { 300 if (charCount == inputCommitPoint) { 301 newPrevWordStartIndex = i; 302 break; 303 } 304 ++charCount; 305 } 306 } 307 if (!mDicNodeState.mDicNodeStatePrevWord.startsWith( 308 &topNode->mDicNodeState.mDicNodeStatePrevWord, newPrevWordStartIndex - 1)) { 309 // Node mismatch. 310 return false; 311 } 312 mDicNodeState.mDicNodeStateInput.truncate(inputCommitPoint); 313 mDicNodeState.mDicNodeStatePrevWord.truncate(newPrevWordStartIndex); 314 return true; 315 } 316 317 void outputResult(int *dest) const { 318 const uint16_t prevWordLength = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(); 319 const uint16_t currentDepth = getDepth(); 320 DicNodeUtils::appendTwoWords(mDicNodeState.mDicNodeStatePrevWord.mPrevWord, 321 prevWordLength, getOutputWordBuf(), currentDepth, dest); 322 DUMP_WORD_AND_SCORE("OUTPUT"); 323 } 324 325 void outputSpacePositionsResult(int *spaceIndices) const { 326 mDicNodeState.mDicNodeStatePrevWord.outputSpacePositions(spaceIndices); 327 } 328 329 bool hasMultipleWords() const { 330 return mDicNodeState.mDicNodeStatePrevWord.getPrevWordCount() > 0; 331 } 332 333 float getProximityCorrectionCount() const { 334 return static_cast<float>(mDicNodeState.mDicNodeStateScoring.getProximityCorrectionCount()); 335 } 336 337 float getEditCorrectionCount() const { 338 return static_cast<float>(mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount()); 339 } 340 341 // Used to prune nodes 342 float getNormalizedCompoundDistance() const { 343 return mDicNodeState.mDicNodeStateScoring.getNormalizedCompoundDistance(); 344 } 345 346 // Used to prune nodes 347 float getNormalizedSpatialDistance() const { 348 return mDicNodeState.mDicNodeStateScoring.getSpatialDistance() 349 / static_cast<float>(getInputIndex(0) + 1); 350 } 351 352 // Used to prune nodes 353 float getCompoundDistance() const { 354 return mDicNodeState.mDicNodeStateScoring.getCompoundDistance(); 355 } 356 357 // Used to prune nodes 358 float getCompoundDistance(const float languageWeight) const { 359 return mDicNodeState.mDicNodeStateScoring.getCompoundDistance(languageWeight); 360 } 361 362 // Used to commit input partially 363 int getPrevWordNodePos() const { 364 return mDicNodeState.mDicNodeStatePrevWord.getPrevWordNodePos(); 365 } 366 367 AK_FORCE_INLINE const int *getOutputWordBuf() const { 368 return mDicNodeState.mDicNodeStateOutput.mWordBuf; 369 } 370 371 int getPrevCodePointG(int pointerId) const { 372 return mDicNodeState.mDicNodeStateInput.getPrevCodePoint(pointerId); 373 } 374 375 // Whether the current codepoint can be an intentional omission, in which case the traversal 376 // algorithm will always check for a possible omission here. 377 bool canBeIntentionalOmission() const { 378 return isIntentionalOmissionCodePoint(getNodeCodePoint()); 379 } 380 381 // Whether the omission is so frequent that it should incur zero cost. 382 bool isZeroCostOmission() const { 383 // TODO: do not hardcode and read from header 384 return (getNodeCodePoint() == KEYCODE_SINGLE_QUOTE); 385 } 386 387 // TODO: remove 388 float getTerminalDiffCostG(int path) const { 389 return mDicNodeState.mDicNodeStateInput.getTerminalDiffCost(path); 390 } 391 392 ////////////////////// 393 // Temporary getter // 394 // TODO: Remove // 395 ////////////////////// 396 // TODO: Remove once touch path is merged into ProximityInfoState 397 // Note: Returned codepoint may be a digraph codepoint if the node is in a composite glyph. 398 int getNodeCodePoint() const { 399 const int codePoint = mDicNodeProperties.getNodeCodePoint(); 400 const DigraphUtils::DigraphCodePointIndex digraphIndex = 401 mDicNodeState.mDicNodeStateScoring.getDigraphIndex(); 402 if (digraphIndex == DigraphUtils::NOT_A_DIGRAPH_INDEX) { 403 return codePoint; 404 } 405 return DigraphUtils::getDigraphCodePointForIndex(codePoint, digraphIndex); 406 } 407 408 //////////////////////////////// 409 // Utils for cost calculation // 410 //////////////////////////////// 411 AK_FORCE_INLINE bool isSameNodeCodePoint(const DicNode *const dicNode) const { 412 return mDicNodeProperties.getNodeCodePoint() 413 == dicNode->mDicNodeProperties.getNodeCodePoint(); 414 } 415 416 // TODO: remove 417 // TODO: rename getNextInputIndex 418 int16_t getInputIndex(int pointerId) const { 419 return mDicNodeState.mDicNodeStateInput.getInputIndex(pointerId); 420 } 421 422 //////////////////////////////////// 423 // Getter of features for scoring // 424 //////////////////////////////////// 425 float getSpatialDistanceForScoring() const { 426 return mDicNodeState.mDicNodeStateScoring.getSpatialDistance(); 427 } 428 429 float getLanguageDistanceForScoring() const { 430 return mDicNodeState.mDicNodeStateScoring.getLanguageDistance(); 431 } 432 433 float getLanguageDistanceRatePerWordForScoring() const { 434 const float langDist = getLanguageDistanceForScoring(); 435 const float totalWordCount = 436 static_cast<float>(mDicNodeState.mDicNodeStatePrevWord.getPrevWordCount() + 1); 437 return langDist / totalWordCount; 438 } 439 440 float getRawLength() const { 441 return mDicNodeState.mDicNodeStateScoring.getRawLength(); 442 } 443 444 bool isLessThanOneErrorForScoring() const { 445 return mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount() 446 + mDicNodeState.mDicNodeStateScoring.getProximityCorrectionCount() <= 1; 447 } 448 449 DoubleLetterLevel getDoubleLetterLevel() const { 450 return mDicNodeState.mDicNodeStateScoring.getDoubleLetterLevel(); 451 } 452 453 void setDoubleLetterLevel(DoubleLetterLevel doubleLetterLevel) { 454 mDicNodeState.mDicNodeStateScoring.setDoubleLetterLevel(doubleLetterLevel); 455 } 456 457 bool isInDigraph() const { 458 return mDicNodeState.mDicNodeStateScoring.getDigraphIndex() 459 != DigraphUtils::NOT_A_DIGRAPH_INDEX; 460 } 461 462 void advanceDigraphIndex() { 463 mDicNodeState.mDicNodeStateScoring.advanceDigraphIndex(); 464 } 465 466 bool isExactMatch() const { 467 return mDicNodeState.mDicNodeStateScoring.isExactMatch(); 468 } 469 470 uint8_t getFlags() const { 471 return mDicNodeProperties.getFlags(); 472 } 473 474 int getAttributesPos() const { 475 return mDicNodeProperties.getAttributesPos(); 476 } 477 478 inline uint16_t getDepth() const { 479 return mDicNodeProperties.getDepth(); 480 } 481 482 AK_FORCE_INLINE void dump(const char *tag) const { 483 #if DEBUG_DICT 484 DUMP_WORD_AND_SCORE(tag); 485 #if DEBUG_DUMP_ERROR 486 mProfiler.dump(); 487 #endif 488 #endif 489 } 490 491 void setReleaseListener(DicNodeReleaseListener *releaseListener) { 492 mReleaseListener = releaseListener; 493 } 494 495 AK_FORCE_INLINE bool compare(const DicNode *right) { 496 if (!isUsed() && !right->isUsed()) { 497 // Compare pointer values here for stable comparison 498 return this > right; 499 } 500 if (!isUsed()) { 501 return true; 502 } 503 if (!right->isUsed()) { 504 return false; 505 } 506 const float diff = 507 right->getNormalizedCompoundDistance() - getNormalizedCompoundDistance(); 508 static const float MIN_DIFF = 0.000001f; 509 if (diff > MIN_DIFF) { 510 return true; 511 } else if (diff < -MIN_DIFF) { 512 return false; 513 } 514 const int depth = getDepth(); 515 const int depthDiff = right->getDepth() - depth; 516 if (depthDiff != 0) { 517 return depthDiff > 0; 518 } 519 for (int i = 0; i < depth; ++i) { 520 const int codePoint = mDicNodeState.mDicNodeStateOutput.getCodePointAt(i); 521 const int rightCodePoint = right->mDicNodeState.mDicNodeStateOutput.getCodePointAt(i); 522 if (codePoint != rightCodePoint) { 523 return rightCodePoint > codePoint; 524 } 525 } 526 // Compare pointer values here for stable comparison 527 return this > right; 528 } 529 530 private: 531 DicNodeProperties mDicNodeProperties; 532 DicNodeState mDicNodeState; 533 // TODO: Remove 534 bool mIsCachedForNextSuggestion; 535 bool mIsUsed; 536 DicNodeReleaseListener *mReleaseListener; 537 538 AK_FORCE_INLINE int getTotalInputIndex() const { 539 int index = 0; 540 for (int i = 0; i < MAX_POINTER_COUNT_G; i++) { 541 index += mDicNodeState.mDicNodeStateInput.getInputIndex(i); 542 } 543 return index; 544 } 545 546 // Caveat: Must not be called outside Weighting 547 // This restriction is guaranteed by "friend" 548 AK_FORCE_INLINE void addCost(const float spatialCost, const float languageCost, 549 const bool doNormalization, const int inputSize, const ErrorType errorType) { 550 if (DEBUG_GEO_FULL) { 551 LOGI_SHOW_ADD_COST_PROP; 552 } 553 mDicNodeState.mDicNodeStateScoring.addCost(spatialCost, languageCost, doNormalization, 554 inputSize, getTotalInputIndex(), errorType); 555 } 556 557 // Caveat: Must not be called outside Weighting 558 // This restriction is guaranteed by "friend" 559 AK_FORCE_INLINE void forwardInputIndex(const int pointerId, const int count, 560 const bool overwritesPrevCodePointByNodeCodePoint) { 561 if (count == 0) { 562 return; 563 } 564 mDicNodeState.mDicNodeStateInput.forwardInputIndex(pointerId, count); 565 if (overwritesPrevCodePointByNodeCodePoint) { 566 mDicNodeState.mDicNodeStateInput.setPrevCodePoint(0, getNodeCodePoint()); 567 } 568 } 569 570 AK_FORCE_INLINE void updateInputIndexG(DicNode_InputStateG *inputStateG) { 571 mDicNodeState.mDicNodeStateInput.updateInputIndexG(inputStateG->mPointerId, 572 inputStateG->mInputIndex, inputStateG->mPrevCodePoint, 573 inputStateG->mTerminalDiffCost, inputStateG->mRawLength); 574 mDicNodeState.mDicNodeStateScoring.addRawLength(inputStateG->mRawLength); 575 mDicNodeState.mDicNodeStateScoring.setDoubleLetterLevel(inputStateG->mDoubleLetterLevel); 576 } 577 }; 578 } // namespace latinime 579 #endif // LATINIME_DIC_NODE_H 580