1 /* 2 * Copyright (C) 2012 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef LATINIME_DIC_NODE_H 18 #define LATINIME_DIC_NODE_H 19 20 #include "defines.h" 21 #include "suggest/core/dicnode/dic_node_profiler.h" 22 #include "suggest/core/dicnode/dic_node_utils.h" 23 #include "suggest/core/dicnode/internal/dic_node_state.h" 24 #include "suggest/core/dicnode/internal/dic_node_properties.h" 25 #include "suggest/core/dictionary/digraph_utils.h" 26 #include "suggest/core/dictionary/error_type_utils.h" 27 #include "suggest/core/layout/proximity_info_state.h" 28 #include "utils/char_utils.h" 29 #include "utils/int_array_view.h" 30 31 #if DEBUG_DICT 32 #define LOGI_SHOW_ADD_COST_PROP \ 33 do { \ 34 char charBuf[50]; \ 35 INTS_TO_CHARS(getOutputWordBuf(), getNodeCodePointCount(), charBuf, NELEMS(charBuf)); \ 36 AKLOGI("%20s, \"%c\", size = %03d, total = %03d, index(0) = %02d, dist = %.4f, %s,,", \ 37 __FUNCTION__, getNodeCodePoint(), inputSize, getTotalInputIndex(), \ 38 getInputIndex(0), getNormalizedCompoundDistance(), charBuf); \ 39 } while (0) 40 #define DUMP_WORD_AND_SCORE(header) \ 41 do { \ 42 char charBuf[50]; \ 43 INTS_TO_CHARS(getOutputWordBuf(), \ 44 getNodeCodePointCount() \ 45 + mDicNodeState.mDicNodeStateOutput.getPrevWordsLength(), \ 46 charBuf, NELEMS(charBuf)); \ 47 AKLOGI("#%8s, %5f, %5f, %5f, %5f, %s, %d, %5f,", header, \ 48 getSpatialDistanceForScoring(), \ 49 mDicNodeState.mDicNodeStateScoring.getLanguageDistance(), \ 50 getNormalizedCompoundDistance(), getRawLength(), charBuf, \ 51 getInputIndex(0), getNormalizedCompoundDistanceAfterFirstWord()); \ 52 } while (0) 53 #else 54 #define LOGI_SHOW_ADD_COST_PROP 55 #define DUMP_WORD_AND_SCORE(header) 56 #endif 57 58 namespace latinime { 59 60 // This struct is purely a bucket to return values. No instances of this struct should be kept. 61 struct DicNode_InputStateG { 62 DicNode_InputStateG() 63 : mNeedsToUpdateInputStateG(false), mPointerId(0), mInputIndex(0), 64 mPrevCodePoint(0), mTerminalDiffCost(0.0f), mRawLength(0.0f), 65 mDoubleLetterLevel(NOT_A_DOUBLE_LETTER) {} 66 67 bool mNeedsToUpdateInputStateG; 68 int mPointerId; 69 int16_t mInputIndex; 70 int mPrevCodePoint; 71 float mTerminalDiffCost; 72 float mRawLength; 73 DoubleLetterLevel mDoubleLetterLevel; 74 }; 75 76 class DicNode { 77 // Caveat: We define Weighting as a friend class of DicNode to let Weighting change 78 // the distance of DicNode. 79 // Caution!!! In general, we avoid using the "friend" access modifier. 80 // This is an exception to explicitly hide DicNode::addCost() from all classes but Weighting. 81 friend class Weighting; 82 83 public: 84 #if DEBUG_DICT 85 DicNodeProfiler mProfiler; 86 #endif 87 88 AK_FORCE_INLINE DicNode() 89 : 90 #if DEBUG_DICT 91 mProfiler(), 92 #endif 93 mDicNodeProperties(), mDicNodeState(), mIsCachedForNextSuggestion(false) {} 94 95 DicNode(const DicNode &dicNode); 96 DicNode &operator=(const DicNode &dicNode); 97 ~DicNode() {} 98 99 // Init for copy 100 void initByCopy(const DicNode *const dicNode) { 101 mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; 102 mDicNodeProperties.initByCopy(&dicNode->mDicNodeProperties); 103 mDicNodeState.initByCopy(&dicNode->mDicNodeState); 104 PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); 105 } 106 107 // Init for root with prevWordIds which is used for n-gram 108 void initAsRoot(const int rootPtNodeArrayPos, const WordIdArrayView prevWordIds) { 109 mIsCachedForNextSuggestion = false; 110 mDicNodeProperties.init(rootPtNodeArrayPos, prevWordIds); 111 mDicNodeState.init(); 112 PROF_NODE_RESET(mProfiler); 113 } 114 115 // Init for root with previous word 116 void initAsRootWithPreviousWord(const DicNode *const dicNode, const int rootPtNodeArrayPos) { 117 mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; 118 WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> newPrevWordIds; 119 newPrevWordIds[0] = dicNode->mDicNodeProperties.getWordId(); 120 dicNode->getPrevWordIds().limit(newPrevWordIds.size() - 1) 121 .copyToArray(&newPrevWordIds, 1 /* offset */); 122 mDicNodeProperties.init(rootPtNodeArrayPos, WordIdArrayView::fromArray(newPrevWordIds)); 123 mDicNodeState.initAsRootWithPreviousWord(&dicNode->mDicNodeState, 124 dicNode->mDicNodeProperties.getDepth()); 125 PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); 126 } 127 128 void initAsPassingChild(const DicNode *parentDicNode) { 129 mIsCachedForNextSuggestion = parentDicNode->mIsCachedForNextSuggestion; 130 const int codePoint = 131 parentDicNode->mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt( 132 parentDicNode->getNodeCodePointCount()); 133 mDicNodeProperties.init(&parentDicNode->mDicNodeProperties, codePoint); 134 mDicNodeState.initByCopy(&parentDicNode->mDicNodeState); 135 PROF_NODE_COPY(&parentDicNode->mProfiler, mProfiler); 136 } 137 138 void initAsChild(const DicNode *const dicNode, const int childrenPtNodeArrayPos, 139 const int wordId, const CodePointArrayView mergedCodePoints) { 140 uint16_t newDepth = static_cast<uint16_t>(dicNode->getNodeCodePointCount() + 1); 141 mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; 142 const uint16_t newLeavingDepth = static_cast<uint16_t>( 143 dicNode->mDicNodeProperties.getLeavingDepth() + mergedCodePoints.size()); 144 mDicNodeProperties.init(childrenPtNodeArrayPos, mergedCodePoints[0], 145 wordId, newDepth, newLeavingDepth, dicNode->mDicNodeProperties.getPrevWordIds()); 146 mDicNodeState.init(&dicNode->mDicNodeState, mergedCodePoints.size(), 147 mergedCodePoints.data()); 148 PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); 149 } 150 151 bool isRoot() const { 152 return getNodeCodePointCount() == 0; 153 } 154 155 bool hasChildren() const { 156 return mDicNodeProperties.hasChildren(); 157 } 158 159 bool isLeavingNode() const { 160 ASSERT(getNodeCodePointCount() <= mDicNodeProperties.getLeavingDepth()); 161 return getNodeCodePointCount() == mDicNodeProperties.getLeavingDepth(); 162 } 163 164 AK_FORCE_INLINE bool isFirstLetter() const { 165 return getNodeCodePointCount() == 1; 166 } 167 168 bool isCached() const { 169 return mIsCachedForNextSuggestion; 170 } 171 172 void setCached() { 173 mIsCachedForNextSuggestion = true; 174 } 175 176 // Check if the current word and the previous word can be considered as a valid multiple word 177 // suggestion. 178 bool isValidMultipleWordSuggestion() const { 179 // Treat suggestion as invalid if the current and the previous word are single character 180 // words. 181 const int prevWordLen = mDicNodeState.mDicNodeStateOutput.getPrevWordsLength() 182 - mDicNodeState.mDicNodeStateOutput.getPrevWordStart() - 1; 183 const int currentWordLen = getNodeCodePointCount(); 184 return (prevWordLen != 1 || currentWordLen != 1); 185 } 186 187 bool isFirstCharUppercase() const { 188 const int c = mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt(0); 189 return CharUtils::isAsciiUpper(c); 190 } 191 192 bool isCompletion(const int inputSize) const { 193 return mDicNodeState.mDicNodeStateInput.getInputIndex(0) >= inputSize; 194 } 195 196 bool canDoLookAheadCorrection(const int inputSize) const { 197 return mDicNodeState.mDicNodeStateInput.getInputIndex(0) < inputSize - 1; 198 } 199 200 // Used to get n-gram probability in DicNodeUtils. 201 int getWordId() const { 202 return mDicNodeProperties.getWordId(); 203 } 204 205 const WordIdArrayView getPrevWordIds() const { 206 return mDicNodeProperties.getPrevWordIds(); 207 } 208 209 // Used in DicNodeUtils 210 int getChildrenPtNodeArrayPos() const { 211 return mDicNodeProperties.getChildrenPtNodeArrayPos(); 212 } 213 214 AK_FORCE_INLINE bool isTerminalDicNode() const { 215 const bool isTerminalPtNode = mDicNodeProperties.isTerminal(); 216 const int currentDicNodeDepth = getNodeCodePointCount(); 217 const int terminalDicNodeDepth = mDicNodeProperties.getLeavingDepth(); 218 return isTerminalPtNode && currentDicNodeDepth > 0 219 && currentDicNodeDepth == terminalDicNodeDepth; 220 } 221 222 bool shouldBeFilteredBySafetyNetForBigram() const { 223 const uint16_t currentDepth = getNodeCodePointCount(); 224 const int prevWordLen = mDicNodeState.mDicNodeStateOutput.getPrevWordsLength() 225 - mDicNodeState.mDicNodeStateOutput.getPrevWordStart() - 1; 226 return !(currentDepth > 0 && (currentDepth != 1 || prevWordLen != 1)); 227 } 228 229 bool hasMatchedOrProximityCodePoints() const { 230 // This DicNode does not have matched or proximity code points when all code points have 231 // been handled as edit corrections or completion so far. 232 const int editCorrectionCount = mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount(); 233 const int completionCount = mDicNodeState.mDicNodeStateScoring.getCompletionCount(); 234 return (editCorrectionCount + completionCount) < getNodeCodePointCount(); 235 } 236 237 bool isTotalInputSizeExceedingLimit() const { 238 // TODO: 3 can be 2? Needs to be investigated. 239 // TODO: Have a const variable for 3 (or 2) 240 return getTotalNodeCodePointCount() > MAX_WORD_LENGTH - 3; 241 } 242 243 void outputResult(int *dest) const { 244 memmove(dest, getOutputWordBuf(), getTotalNodeCodePointCount() * sizeof(dest[0])); 245 DUMP_WORD_AND_SCORE("OUTPUT"); 246 } 247 248 // "Total" in this context (and other methods in this class) means the whole suggestion. When 249 // this represents a multi-word suggestion, the referenced PtNode (in mDicNodeState) is only 250 // the one that corresponds to the last word of the suggestion, and all the previous words 251 // are concatenated together in mDicNodeStateOutput. 252 int getTotalNodeSpaceCount() const { 253 if (!hasMultipleWords()) { 254 return 0; 255 } 256 return CharUtils::getSpaceCount(mDicNodeState.mDicNodeStateOutput.getCodePointBuf(), 257 mDicNodeState.mDicNodeStateOutput.getPrevWordsLength()); 258 } 259 260 int getSecondWordFirstInputIndex(const ProximityInfoState *const pInfoState) const { 261 const int inputIndex = mDicNodeState.mDicNodeStateOutput.getSecondWordFirstInputIndex(); 262 if (inputIndex == NOT_AN_INDEX) { 263 return NOT_AN_INDEX; 264 } else { 265 return pInfoState->getInputIndexOfSampledPoint(inputIndex); 266 } 267 } 268 269 bool hasMultipleWords() const { 270 return mDicNodeState.mDicNodeStateOutput.getPrevWordCount() > 0; 271 } 272 273 int getProximityCorrectionCount() const { 274 return mDicNodeState.mDicNodeStateScoring.getProximityCorrectionCount(); 275 } 276 277 int getEditCorrectionCount() const { 278 return mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount(); 279 } 280 281 // Used to prune nodes 282 float getNormalizedCompoundDistance() const { 283 return mDicNodeState.mDicNodeStateScoring.getNormalizedCompoundDistance(); 284 } 285 286 // Used to prune nodes 287 float getNormalizedSpatialDistance() const { 288 return mDicNodeState.mDicNodeStateScoring.getSpatialDistance() 289 / static_cast<float>(getInputIndex(0) + 1); 290 } 291 292 // Used to prune nodes 293 float getCompoundDistance() const { 294 return mDicNodeState.mDicNodeStateScoring.getCompoundDistance(); 295 } 296 297 // Used to prune nodes 298 float getCompoundDistance(const float weightOfLangModelVsSpatialModel) const { 299 return mDicNodeState.mDicNodeStateScoring.getCompoundDistance( 300 weightOfLangModelVsSpatialModel); 301 } 302 303 AK_FORCE_INLINE const int *getOutputWordBuf() const { 304 return mDicNodeState.mDicNodeStateOutput.getCodePointBuf(); 305 } 306 307 int getPrevCodePointG(int pointerId) const { 308 return mDicNodeState.mDicNodeStateInput.getPrevCodePoint(pointerId); 309 } 310 311 // Whether the current codepoint can be an intentional omission, in which case the traversal 312 // algorithm will always check for a possible omission here. 313 bool canBeIntentionalOmission() const { 314 return CharUtils::isIntentionalOmissionCodePoint(getNodeCodePoint()); 315 } 316 317 // Whether the omission is so frequent that it should incur zero cost. 318 bool isZeroCostOmission() const { 319 // TODO: do not hardcode and read from header 320 return (getNodeCodePoint() == KEYCODE_SINGLE_QUOTE); 321 } 322 323 // TODO: remove 324 float getTerminalDiffCostG(int path) const { 325 return mDicNodeState.mDicNodeStateInput.getTerminalDiffCost(path); 326 } 327 328 ////////////////////// 329 // Temporary getter // 330 // TODO: Remove // 331 ////////////////////// 332 // TODO: Remove once touch path is merged into ProximityInfoState 333 // Note: Returned codepoint may be a digraph codepoint if the node is in a composite glyph. 334 int getNodeCodePoint() const { 335 const int codePoint = mDicNodeProperties.getDicNodeCodePoint(); 336 const DigraphUtils::DigraphCodePointIndex digraphIndex = 337 mDicNodeState.mDicNodeStateScoring.getDigraphIndex(); 338 if (digraphIndex == DigraphUtils::NOT_A_DIGRAPH_INDEX) { 339 return codePoint; 340 } 341 return DigraphUtils::getDigraphCodePointForIndex(codePoint, digraphIndex); 342 } 343 344 //////////////////////////////// 345 // Utils for cost calculation // 346 //////////////////////////////// 347 AK_FORCE_INLINE bool isSameNodeCodePoint(const DicNode *const dicNode) const { 348 return mDicNodeProperties.getDicNodeCodePoint() 349 == dicNode->mDicNodeProperties.getDicNodeCodePoint(); 350 } 351 352 // TODO: remove 353 // TODO: rename getNextInputIndex 354 int16_t getInputIndex(int pointerId) const { 355 return mDicNodeState.mDicNodeStateInput.getInputIndex(pointerId); 356 } 357 358 //////////////////////////////////// 359 // Getter of features for scoring // 360 //////////////////////////////////// 361 float getSpatialDistanceForScoring() const { 362 return mDicNodeState.mDicNodeStateScoring.getSpatialDistance(); 363 } 364 365 // For space-aware gestures, we store the normalized distance at the char index 366 // that ends the first word of the suggestion. We call this the distance after 367 // first word. 368 float getNormalizedCompoundDistanceAfterFirstWord() const { 369 return mDicNodeState.mDicNodeStateScoring.getNormalizedCompoundDistanceAfterFirstWord(); 370 } 371 372 float getRawLength() const { 373 return mDicNodeState.mDicNodeStateScoring.getRawLength(); 374 } 375 376 DoubleLetterLevel getDoubleLetterLevel() const { 377 return mDicNodeState.mDicNodeStateScoring.getDoubleLetterLevel(); 378 } 379 380 void setDoubleLetterLevel(DoubleLetterLevel doubleLetterLevel) { 381 mDicNodeState.mDicNodeStateScoring.setDoubleLetterLevel(doubleLetterLevel); 382 } 383 384 bool isInDigraph() const { 385 return mDicNodeState.mDicNodeStateScoring.getDigraphIndex() 386 != DigraphUtils::NOT_A_DIGRAPH_INDEX; 387 } 388 389 void advanceDigraphIndex() { 390 mDicNodeState.mDicNodeStateScoring.advanceDigraphIndex(); 391 } 392 393 ErrorTypeUtils::ErrorType getContainedErrorTypes() const { 394 return mDicNodeState.mDicNodeStateScoring.getContainedErrorTypes(); 395 } 396 397 inline uint16_t getNodeCodePointCount() const { 398 return mDicNodeProperties.getDepth(); 399 } 400 401 // Returns code point count including spaces 402 inline uint16_t getTotalNodeCodePointCount() const { 403 return getNodeCodePointCount() + mDicNodeState.mDicNodeStateOutput.getPrevWordsLength(); 404 } 405 406 AK_FORCE_INLINE void dump(const char *tag) const { 407 #if DEBUG_DICT 408 DUMP_WORD_AND_SCORE(tag); 409 #if DEBUG_DUMP_ERROR 410 mProfiler.dump(); 411 #endif 412 #endif 413 } 414 415 AK_FORCE_INLINE bool compare(const DicNode *right) const { 416 // Promote exact matches to prevent them from being pruned. 417 const bool leftExactMatch = ErrorTypeUtils::isExactMatch(getContainedErrorTypes()); 418 const bool rightExactMatch = ErrorTypeUtils::isExactMatch(right->getContainedErrorTypes()); 419 if (leftExactMatch != rightExactMatch) { 420 return leftExactMatch; 421 } 422 const float diff = 423 right->getNormalizedCompoundDistance() - getNormalizedCompoundDistance(); 424 static const float MIN_DIFF = 0.000001f; 425 if (diff > MIN_DIFF) { 426 return true; 427 } else if (diff < -MIN_DIFF) { 428 return false; 429 } 430 const int depth = getNodeCodePointCount(); 431 const int depthDiff = right->getNodeCodePointCount() - depth; 432 if (depthDiff != 0) { 433 return depthDiff > 0; 434 } 435 for (int i = 0; i < depth; ++i) { 436 const int codePoint = mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt(i); 437 const int rightCodePoint = 438 right->mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt(i); 439 if (codePoint != rightCodePoint) { 440 return rightCodePoint > codePoint; 441 } 442 } 443 // Compare pointer values here for stable comparison 444 return this > right; 445 } 446 447 private: 448 DicNodeProperties mDicNodeProperties; 449 DicNodeState mDicNodeState; 450 // TODO: Remove 451 bool mIsCachedForNextSuggestion; 452 453 AK_FORCE_INLINE int getTotalInputIndex() const { 454 int index = 0; 455 for (int i = 0; i < MAX_POINTER_COUNT_G; i++) { 456 index += mDicNodeState.mDicNodeStateInput.getInputIndex(i); 457 } 458 return index; 459 } 460 461 // Caveat: Must not be called outside Weighting 462 // This restriction is guaranteed by "friend" 463 AK_FORCE_INLINE void addCost(const float spatialCost, const float languageCost, 464 const bool doNormalization, const int inputSize, 465 const ErrorTypeUtils::ErrorType errorType) { 466 if (DEBUG_GEO_FULL) { 467 LOGI_SHOW_ADD_COST_PROP; 468 } 469 mDicNodeState.mDicNodeStateScoring.addCost(spatialCost, languageCost, doNormalization, 470 inputSize, getTotalInputIndex(), errorType); 471 } 472 473 // Saves the current normalized compound distance for space-aware gestures. 474 // See getNormalizedCompoundDistanceAfterFirstWord for details. 475 AK_FORCE_INLINE void saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet() { 476 mDicNodeState.mDicNodeStateScoring.saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet(); 477 } 478 479 // Caveat: Must not be called outside Weighting 480 // This restriction is guaranteed by "friend" 481 AK_FORCE_INLINE void forwardInputIndex(const int pointerId, const int count, 482 const bool overwritesPrevCodePointByNodeCodePoint) { 483 if (count == 0) { 484 return; 485 } 486 mDicNodeState.mDicNodeStateInput.forwardInputIndex(pointerId, count); 487 if (overwritesPrevCodePointByNodeCodePoint) { 488 mDicNodeState.mDicNodeStateInput.setPrevCodePoint(0, getNodeCodePoint()); 489 } 490 } 491 492 AK_FORCE_INLINE void updateInputIndexG(const DicNode_InputStateG *const inputStateG) { 493 if (mDicNodeState.mDicNodeStateOutput.getPrevWordCount() == 1 && isFirstLetter()) { 494 mDicNodeState.mDicNodeStateOutput.setSecondWordFirstInputIndex( 495 inputStateG->mInputIndex); 496 } 497 mDicNodeState.mDicNodeStateInput.updateInputIndexG(inputStateG->mPointerId, 498 inputStateG->mInputIndex, inputStateG->mPrevCodePoint, 499 inputStateG->mTerminalDiffCost, inputStateG->mRawLength); 500 mDicNodeState.mDicNodeStateScoring.addRawLength(inputStateG->mRawLength); 501 mDicNodeState.mDicNodeStateScoring.setDoubleLetterLevel(inputStateG->mDoubleLetterLevel); 502 } 503 }; 504 } // namespace latinime 505 #endif // LATINIME_DIC_NODE_H 506