1 /* 2 * Copyright (C) 2011 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef LATINIME_BINARY_FORMAT_H 18 #define LATINIME_BINARY_FORMAT_H 19 20 #include <cstdlib> 21 #include <map> 22 #include <stdint.h> 23 24 #include "bloom_filter.h" 25 #include "char_utils.h" 26 #include "hash_map_compat.h" 27 28 namespace latinime { 29 30 class BinaryFormat { 31 public: 32 // Mask and flags for children address type selection. 33 static const int MASK_GROUP_ADDRESS_TYPE = 0xC0; 34 35 // Flag for single/multiple char group 36 static const int FLAG_HAS_MULTIPLE_CHARS = 0x20; 37 38 // Flag for terminal groups 39 static const int FLAG_IS_TERMINAL = 0x10; 40 41 // Flag for shortcut targets presence 42 static const int FLAG_HAS_SHORTCUT_TARGETS = 0x08; 43 // Flag for bigram presence 44 static const int FLAG_HAS_BIGRAMS = 0x04; 45 // Flag for non-words (typically, shortcut only entries) 46 static const int FLAG_IS_NOT_A_WORD = 0x02; 47 // Flag for blacklist 48 static const int FLAG_IS_BLACKLISTED = 0x01; 49 50 // Attribute (bigram/shortcut) related flags: 51 // Flag for presence of more attributes 52 static const int FLAG_ATTRIBUTE_HAS_NEXT = 0x80; 53 // Flag for sign of offset. If this flag is set, the offset value must be negated. 54 static const int FLAG_ATTRIBUTE_OFFSET_NEGATIVE = 0x40; 55 56 // Mask for attribute probability, stored on 4 bits inside the flags byte. 57 static const int MASK_ATTRIBUTE_PROBABILITY = 0x0F; 58 // The numeric value of the shortcut probability that means 'whitelist'. 59 static const int WHITELIST_SHORTCUT_PROBABILITY = 15; 60 61 // Mask and flags for attribute address type selection. 62 static const int MASK_ATTRIBUTE_ADDRESS_TYPE = 0x30; 63 64 static const int UNKNOWN_FORMAT = -1; 65 static const int SHORTCUT_LIST_SIZE_SIZE = 2; 66 67 static int detectFormat(const uint8_t *const dict, const int dictSize); 68 static int getHeaderSize(const uint8_t *const dict, const int dictSize); 69 static int getFlags(const uint8_t *const dict, const int dictSize); 70 static bool hasBlacklistedOrNotAWordFlag(const int flags); 71 static void readHeaderValue(const uint8_t *const dict, const int dictSize, 72 const char *const key, int *outValue, const int outValueSize); 73 static int readHeaderValueInt(const uint8_t *const dict, const int dictSize, 74 const char *const key); 75 static int getGroupCountAndForwardPointer(const uint8_t *const dict, int *pos); 76 static uint8_t getFlagsAndForwardPointer(const uint8_t *const dict, int *pos); 77 static int getCodePointAndForwardPointer(const uint8_t *const dict, int *pos); 78 static int readProbabilityWithoutMovingPointer(const uint8_t *const dict, const int pos); 79 static int skipOtherCharacters(const uint8_t *const dict, const int pos); 80 static int skipChildrenPosition(const uint8_t flags, const int pos); 81 static int skipProbability(const uint8_t flags, const int pos); 82 static int skipShortcuts(const uint8_t *const dict, const uint8_t flags, const int pos); 83 static int skipChildrenPosAndAttributes(const uint8_t *const dict, const uint8_t flags, 84 const int pos); 85 static int readChildrenPosition(const uint8_t *const dict, const uint8_t flags, const int pos); 86 static bool hasChildrenInFlags(const uint8_t flags); 87 static int getAttributeAddressAndForwardPointer(const uint8_t *const dict, const uint8_t flags, 88 int *pos); 89 static int getAttributeProbabilityFromFlags(const int flags); 90 static int getTerminalPosition(const uint8_t *const root, const int *const inWord, 91 const int length, const bool forceLowerCaseSearch); 92 static int getWordAtAddress(const uint8_t *const root, const int address, const int maxDepth, 93 int *outWord, int *outUnigramProbability); 94 static int computeProbabilityForBigram( 95 const int unigramProbability, const int bigramProbability); 96 static int getProbability(const int position, const std::map<int, int> *bigramMap, 97 const uint8_t *bigramFilter, const int unigramProbability); 98 static int getBigramProbabilityFromHashMap(const int position, 99 const hash_map_compat<int, int> *bigramMap, const int unigramProbability); 100 static float getMultiWordCostMultiplier(const uint8_t *const dict, const int dictSize); 101 static void fillBigramProbabilityToHashMap(const uint8_t *const root, int position, 102 hash_map_compat<int, int> *bigramMap); 103 static int getBigramProbability(const uint8_t *const root, int position, 104 const int nextPosition, const int unigramProbability); 105 106 // Flags for special processing 107 // Those *must* match the flags in makedict (BinaryDictInputOutput#*_PROCESSING_FLAG) or 108 // something very bad (like, the apocalypse) will happen. Please update both at the same time. 109 enum { 110 REQUIRES_GERMAN_UMLAUT_PROCESSING = 0x1, 111 REQUIRES_FRENCH_LIGATURES_PROCESSING = 0x4 112 }; 113 114 private: 115 DISALLOW_IMPLICIT_CONSTRUCTORS(BinaryFormat); 116 static int getBigramListPositionForWordPosition(const uint8_t *const root, int position); 117 118 static const int FLAG_GROUP_ADDRESS_TYPE_NOADDRESS = 0x00; 119 static const int FLAG_GROUP_ADDRESS_TYPE_ONEBYTE = 0x40; 120 static const int FLAG_GROUP_ADDRESS_TYPE_TWOBYTES = 0x80; 121 static const int FLAG_GROUP_ADDRESS_TYPE_THREEBYTES = 0xC0; 122 static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE = 0x10; 123 static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES = 0x20; 124 static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES = 0x30; 125 126 // Any file smaller than this is not a dictionary. 127 static const int DICTIONARY_MINIMUM_SIZE = 4; 128 // Originally, format version 1 had a 16-bit magic number, then the version number `01' 129 // then options that must be 0. Hence the first 32-bits of the format are always as follow 130 // and it's okay to consider them a magic number as a whole. 131 static const int FORMAT_VERSION_1_MAGIC_NUMBER = 0x78B10100; 132 static const int FORMAT_VERSION_1_HEADER_SIZE = 5; 133 // The versions of Latin IME that only handle format version 1 only test for the magic 134 // number, so we had to change it so that version 2 files would be rejected by older 135 // implementations. On this occasion, we made the magic number 32 bits long. 136 static const int FORMAT_VERSION_2_MAGIC_NUMBER = -1681835266; // 0x9BC13AFE 137 // Magic number (4 bytes), version (2 bytes), options (2 bytes), header size (4 bytes) = 12 138 static const int FORMAT_VERSION_2_MINIMUM_SIZE = 12; 139 140 static const int CHARACTER_ARRAY_TERMINATOR_SIZE = 1; 141 static const int MINIMAL_ONE_BYTE_CHARACTER_VALUE = 0x20; 142 static const int CHARACTER_ARRAY_TERMINATOR = 0x1F; 143 static const int MULTIPLE_BYTE_CHARACTER_ADDITIONAL_SIZE = 2; 144 static const int NO_FLAGS = 0; 145 static int skipAllAttributes(const uint8_t *const dict, const uint8_t flags, const int pos); 146 static int skipBigrams(const uint8_t *const dict, const uint8_t flags, const int pos); 147 }; 148 149 AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict, const int dictSize) { 150 // The magic number is stored big-endian. 151 // If the dictionary is less than 4 bytes, we can't even read the magic number, so we don't 152 // understand this format. 153 if (dictSize < DICTIONARY_MINIMUM_SIZE) return UNKNOWN_FORMAT; 154 const int magicNumber = (dict[0] << 24) + (dict[1] << 16) + (dict[2] << 8) + dict[3]; 155 switch (magicNumber) { 156 case FORMAT_VERSION_1_MAGIC_NUMBER: 157 // Format 1 header is exactly 5 bytes long and looks like: 158 // Magic number (2 bytes) 0x78 0xB1 159 // Version number (1 byte) 0x01 160 // Options (2 bytes) must be 0x00 0x00 161 return 1; 162 case FORMAT_VERSION_2_MAGIC_NUMBER: 163 // Version 2 dictionaries are at least 12 bytes long (see below details for the header). 164 // If this dictionary has the version 2 magic number but is less than 12 bytes long, then 165 // it's an unknown format and we need to avoid confidently reading the next bytes. 166 if (dictSize < FORMAT_VERSION_2_MINIMUM_SIZE) return UNKNOWN_FORMAT; 167 // Format 2 header is as follows: 168 // Magic number (4 bytes) 0x9B 0xC1 0x3A 0xFE 169 // Version number (2 bytes) 0x00 0x02 170 // Options (2 bytes) 171 // Header size (4 bytes) : integer, big endian 172 return (dict[4] << 8) + dict[5]; 173 default: 174 return UNKNOWN_FORMAT; 175 } 176 } 177 178 inline int BinaryFormat::getFlags(const uint8_t *const dict, const int dictSize) { 179 switch (detectFormat(dict, dictSize)) { 180 case 1: 181 return NO_FLAGS; // TODO: NO_FLAGS is unused anywhere else? 182 default: 183 return (dict[6] << 8) + dict[7]; 184 } 185 } 186 187 inline bool BinaryFormat::hasBlacklistedOrNotAWordFlag(const int flags) { 188 return (flags & (FLAG_IS_BLACKLISTED | FLAG_IS_NOT_A_WORD)) != 0; 189 } 190 191 inline int BinaryFormat::getHeaderSize(const uint8_t *const dict, const int dictSize) { 192 switch (detectFormat(dict, dictSize)) { 193 case 1: 194 return FORMAT_VERSION_1_HEADER_SIZE; 195 case 2: 196 // See the format of the header in the comment in detectFormat() above 197 return (dict[8] << 24) + (dict[9] << 16) + (dict[10] << 8) + dict[11]; 198 default: 199 return S_INT_MAX; 200 } 201 } 202 203 inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const int dictSize, 204 const char *const key, int *outValue, const int outValueSize) { 205 int outValueIndex = 0; 206 // Only format 2 and above have header attributes as {key,value} string pairs. For prior 207 // formats, we just return an empty string, as if the key wasn't found. 208 if (2 <= detectFormat(dict, dictSize)) { 209 const int headerOptionsOffset = 4 /* magic number */ 210 + 2 /* dictionary version */ + 2 /* flags */; 211 const int headerSize = 212 (dict[headerOptionsOffset] << 24) + (dict[headerOptionsOffset + 1] << 16) 213 + (dict[headerOptionsOffset + 2] << 8) + dict[headerOptionsOffset + 3]; 214 const int headerEnd = headerOptionsOffset + 4 + headerSize; 215 int index = headerOptionsOffset + 4; 216 while (index < headerEnd) { 217 int keyIndex = 0; 218 int codePoint = getCodePointAndForwardPointer(dict, &index); 219 while (codePoint != NOT_A_CODE_POINT) { 220 if (codePoint != key[keyIndex++]) { 221 break; 222 } 223 codePoint = getCodePointAndForwardPointer(dict, &index); 224 } 225 if (codePoint == NOT_A_CODE_POINT && key[keyIndex] == 0) { 226 // We found the key! Copy and return the value. 227 codePoint = getCodePointAndForwardPointer(dict, &index); 228 while (codePoint != NOT_A_CODE_POINT && outValueIndex < outValueSize) { 229 outValue[outValueIndex++] = codePoint; 230 codePoint = getCodePointAndForwardPointer(dict, &index); 231 } 232 // Finished copying. Break to go to the termination code. 233 break; 234 } 235 // We didn't find the key, skip the remainder of it and its value 236 while (codePoint != NOT_A_CODE_POINT) { 237 codePoint = getCodePointAndForwardPointer(dict, &index); 238 } 239 codePoint = getCodePointAndForwardPointer(dict, &index); 240 while (codePoint != NOT_A_CODE_POINT) { 241 codePoint = getCodePointAndForwardPointer(dict, &index); 242 } 243 } 244 // We couldn't find it - fall through and return an empty value. 245 } 246 // Put a terminator 0 if possible at all (always unless outValueSize is <= 0) 247 if (outValueIndex >= outValueSize) outValueIndex = outValueSize - 1; 248 if (outValueIndex >= 0) outValue[outValueIndex] = 0; 249 } 250 251 inline int BinaryFormat::readHeaderValueInt(const uint8_t *const dict, const int dictSize, 252 const char *const key) { 253 const int bufferSize = LARGEST_INT_DIGIT_COUNT; 254 int intBuffer[bufferSize]; 255 char charBuffer[bufferSize]; 256 BinaryFormat::readHeaderValue(dict, dictSize, key, intBuffer, bufferSize); 257 for (int i = 0; i < bufferSize; ++i) { 258 charBuffer[i] = intBuffer[i]; 259 } 260 // If not a number, return S_INT_MIN 261 if (!isdigit(charBuffer[0])) return S_INT_MIN; 262 return atoi(charBuffer); 263 } 264 265 AK_FORCE_INLINE int BinaryFormat::getGroupCountAndForwardPointer(const uint8_t *const dict, 266 int *pos) { 267 const int msb = dict[(*pos)++]; 268 if (msb < 0x80) return msb; 269 return ((msb & 0x7F) << 8) | dict[(*pos)++]; 270 } 271 272 inline float BinaryFormat::getMultiWordCostMultiplier(const uint8_t *const dict, 273 const int dictSize) { 274 const int headerValue = readHeaderValueInt(dict, dictSize, 275 "MULTIPLE_WORDS_DEMOTION_RATE"); 276 if (headerValue == S_INT_MIN) { 277 return 1.0f; 278 } 279 if (headerValue <= 0) { 280 return static_cast<float>(MAX_VALUE_FOR_WEIGHTING); 281 } 282 return 100.0f / static_cast<float>(headerValue); 283 } 284 285 inline uint8_t BinaryFormat::getFlagsAndForwardPointer(const uint8_t *const dict, int *pos) { 286 return dict[(*pos)++]; 287 } 288 289 AK_FORCE_INLINE int BinaryFormat::getCodePointAndForwardPointer(const uint8_t *const dict, 290 int *pos) { 291 const int origin = *pos; 292 const int codePoint = dict[origin]; 293 if (codePoint < MINIMAL_ONE_BYTE_CHARACTER_VALUE) { 294 if (codePoint == CHARACTER_ARRAY_TERMINATOR) { 295 *pos = origin + 1; 296 return NOT_A_CODE_POINT; 297 } else { 298 *pos = origin + 3; 299 const int char_1 = codePoint << 16; 300 const int char_2 = char_1 + (dict[origin + 1] << 8); 301 return char_2 + dict[origin + 2]; 302 } 303 } else { 304 *pos = origin + 1; 305 return codePoint; 306 } 307 } 308 309 inline int BinaryFormat::readProbabilityWithoutMovingPointer(const uint8_t *const dict, 310 const int pos) { 311 return dict[pos]; 312 } 313 314 AK_FORCE_INLINE int BinaryFormat::skipOtherCharacters(const uint8_t *const dict, const int pos) { 315 int currentPos = pos; 316 int character = dict[currentPos++]; 317 while (CHARACTER_ARRAY_TERMINATOR != character) { 318 if (character < MINIMAL_ONE_BYTE_CHARACTER_VALUE) { 319 currentPos += MULTIPLE_BYTE_CHARACTER_ADDITIONAL_SIZE; 320 } 321 character = dict[currentPos++]; 322 } 323 return currentPos; 324 } 325 326 static inline int attributeAddressSize(const uint8_t flags) { 327 static const int ATTRIBUTE_ADDRESS_SHIFT = 4; 328 return (flags & BinaryFormat::MASK_ATTRIBUTE_ADDRESS_TYPE) >> ATTRIBUTE_ADDRESS_SHIFT; 329 /* Note: this is a value-dependant optimization of what may probably be 330 more readably written this way: 331 switch (flags * BinaryFormat::MASK_ATTRIBUTE_ADDRESS_TYPE) { 332 case FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE: return 1; 333 case FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES: return 2; 334 case FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTE: return 3; 335 default: return 0; 336 } 337 */ 338 } 339 340 static AK_FORCE_INLINE int skipExistingBigrams(const uint8_t *const dict, const int pos) { 341 int currentPos = pos; 342 uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(dict, ¤tPos); 343 while (flags & BinaryFormat::FLAG_ATTRIBUTE_HAS_NEXT) { 344 currentPos += attributeAddressSize(flags); 345 flags = BinaryFormat::getFlagsAndForwardPointer(dict, ¤tPos); 346 } 347 currentPos += attributeAddressSize(flags); 348 return currentPos; 349 } 350 351 static inline int childrenAddressSize(const uint8_t flags) { 352 static const int CHILDREN_ADDRESS_SHIFT = 6; 353 return (BinaryFormat::MASK_GROUP_ADDRESS_TYPE & flags) >> CHILDREN_ADDRESS_SHIFT; 354 /* See the note in attributeAddressSize. The same applies here */ 355 } 356 357 static AK_FORCE_INLINE int shortcutByteSize(const uint8_t *const dict, const int pos) { 358 return (static_cast<int>(dict[pos] << 8)) + (dict[pos + 1]); 359 } 360 361 inline int BinaryFormat::skipChildrenPosition(const uint8_t flags, const int pos) { 362 return pos + childrenAddressSize(flags); 363 } 364 365 inline int BinaryFormat::skipProbability(const uint8_t flags, const int pos) { 366 return FLAG_IS_TERMINAL & flags ? pos + 1 : pos; 367 } 368 369 AK_FORCE_INLINE int BinaryFormat::skipShortcuts(const uint8_t *const dict, const uint8_t flags, 370 const int pos) { 371 if (FLAG_HAS_SHORTCUT_TARGETS & flags) { 372 return pos + shortcutByteSize(dict, pos); 373 } else { 374 return pos; 375 } 376 } 377 378 AK_FORCE_INLINE int BinaryFormat::skipBigrams(const uint8_t *const dict, const uint8_t flags, 379 const int pos) { 380 if (FLAG_HAS_BIGRAMS & flags) { 381 return skipExistingBigrams(dict, pos); 382 } else { 383 return pos; 384 } 385 } 386 387 AK_FORCE_INLINE int BinaryFormat::skipAllAttributes(const uint8_t *const dict, const uint8_t flags, 388 const int pos) { 389 // This function skips all attributes: shortcuts and bigrams. 390 int newPos = pos; 391 newPos = skipShortcuts(dict, flags, newPos); 392 newPos = skipBigrams(dict, flags, newPos); 393 return newPos; 394 } 395 396 AK_FORCE_INLINE int BinaryFormat::skipChildrenPosAndAttributes(const uint8_t *const dict, 397 const uint8_t flags, const int pos) { 398 int currentPos = pos; 399 currentPos = skipChildrenPosition(flags, currentPos); 400 currentPos = skipAllAttributes(dict, flags, currentPos); 401 return currentPos; 402 } 403 404 AK_FORCE_INLINE int BinaryFormat::readChildrenPosition(const uint8_t *const dict, 405 const uint8_t flags, const int pos) { 406 int offset = 0; 407 switch (MASK_GROUP_ADDRESS_TYPE & flags) { 408 case FLAG_GROUP_ADDRESS_TYPE_ONEBYTE: 409 offset = dict[pos]; 410 break; 411 case FLAG_GROUP_ADDRESS_TYPE_TWOBYTES: 412 offset = dict[pos] << 8; 413 offset += dict[pos + 1]; 414 break; 415 case FLAG_GROUP_ADDRESS_TYPE_THREEBYTES: 416 offset = dict[pos] << 16; 417 offset += dict[pos + 1] << 8; 418 offset += dict[pos + 2]; 419 break; 420 default: 421 // If we come here, it means we asked for the children of a word with 422 // no children. 423 return -1; 424 } 425 return pos + offset; 426 } 427 428 inline bool BinaryFormat::hasChildrenInFlags(const uint8_t flags) { 429 return (FLAG_GROUP_ADDRESS_TYPE_NOADDRESS != (MASK_GROUP_ADDRESS_TYPE & flags)); 430 } 431 432 AK_FORCE_INLINE int BinaryFormat::getAttributeAddressAndForwardPointer(const uint8_t *const dict, 433 const uint8_t flags, int *pos) { 434 int offset = 0; 435 const int origin = *pos; 436 switch (MASK_ATTRIBUTE_ADDRESS_TYPE & flags) { 437 case FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE: 438 offset = dict[origin]; 439 *pos = origin + 1; 440 break; 441 case FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES: 442 offset = dict[origin] << 8; 443 offset += dict[origin + 1]; 444 *pos = origin + 2; 445 break; 446 case FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES: 447 offset = dict[origin] << 16; 448 offset += dict[origin + 1] << 8; 449 offset += dict[origin + 2]; 450 *pos = origin + 3; 451 break; 452 } 453 if (FLAG_ATTRIBUTE_OFFSET_NEGATIVE & flags) { 454 return origin - offset; 455 } else { 456 return origin + offset; 457 } 458 } 459 460 inline int BinaryFormat::getAttributeProbabilityFromFlags(const int flags) { 461 return flags & MASK_ATTRIBUTE_PROBABILITY; 462 } 463 464 // This function gets the byte position of the last chargroup of the exact matching word in the 465 // dictionary. If no match is found, it returns NOT_VALID_WORD. 466 AK_FORCE_INLINE int BinaryFormat::getTerminalPosition(const uint8_t *const root, 467 const int *const inWord, const int length, const bool forceLowerCaseSearch) { 468 int pos = 0; 469 int wordPos = 0; 470 471 while (true) { 472 // If we already traversed the tree further than the word is long, there means 473 // there was no match (or we would have found it). 474 if (wordPos >= length) return NOT_VALID_WORD; 475 int charGroupCount = BinaryFormat::getGroupCountAndForwardPointer(root, &pos); 476 const int wChar = forceLowerCaseSearch ? toLowerCase(inWord[wordPos]) : inWord[wordPos]; 477 while (true) { 478 // If there are no more character groups in this node, it means we could not 479 // find a matching character for this depth, therefore there is no match. 480 if (0 >= charGroupCount) return NOT_VALID_WORD; 481 const int charGroupPos = pos; 482 const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(root, &pos); 483 int character = BinaryFormat::getCodePointAndForwardPointer(root, &pos); 484 if (character == wChar) { 485 // This is the correct node. Only one character group may start with the same 486 // char within a node, so either we found our match in this node, or there is 487 // no match and we can return NOT_VALID_WORD. So we will check all the characters 488 // in this character group indeed does match. 489 if (FLAG_HAS_MULTIPLE_CHARS & flags) { 490 character = BinaryFormat::getCodePointAndForwardPointer(root, &pos); 491 while (NOT_A_CODE_POINT != character) { 492 ++wordPos; 493 // If we shoot the length of the word we search for, or if we find a single 494 // character that does not match, as explained above, it means the word is 495 // not in the dictionary (by virtue of this chargroup being the only one to 496 // match the word on the first character, but not matching the whole word). 497 if (wordPos >= length) return NOT_VALID_WORD; 498 if (inWord[wordPos] != character) return NOT_VALID_WORD; 499 character = BinaryFormat::getCodePointAndForwardPointer(root, &pos); 500 } 501 } 502 // If we come here we know that so far, we do match. Either we are on a terminal 503 // and we match the length, in which case we found it, or we traverse children. 504 // If we don't match the length AND don't have children, then a word in the 505 // dictionary fully matches a prefix of the searched word but not the full word. 506 ++wordPos; 507 if (FLAG_IS_TERMINAL & flags) { 508 if (wordPos == length) { 509 return charGroupPos; 510 } 511 pos = BinaryFormat::skipProbability(FLAG_IS_TERMINAL, pos); 512 } 513 if (FLAG_GROUP_ADDRESS_TYPE_NOADDRESS == (MASK_GROUP_ADDRESS_TYPE & flags)) { 514 return NOT_VALID_WORD; 515 } 516 // We have children and we are still shorter than the word we are searching for, so 517 // we need to traverse children. Put the pointer on the children position, and 518 // break 519 pos = BinaryFormat::readChildrenPosition(root, flags, pos); 520 break; 521 } else { 522 // This chargroup does not match, so skip the remaining part and go to the next. 523 if (FLAG_HAS_MULTIPLE_CHARS & flags) { 524 pos = BinaryFormat::skipOtherCharacters(root, pos); 525 } 526 pos = BinaryFormat::skipProbability(flags, pos); 527 pos = BinaryFormat::skipChildrenPosAndAttributes(root, flags, pos); 528 } 529 --charGroupCount; 530 } 531 } 532 } 533 534 // This function searches for a terminal in the dictionary by its address. 535 // Due to the fact that words are ordered in the dictionary in a strict breadth-first order, 536 // it is possible to check for this with advantageous complexity. For each node, we search 537 // for groups with children and compare the children address with the address we look for. 538 // When we shoot the address we look for, it means the word we look for is in the children 539 // of the previous group. The only tricky part is the fact that if we arrive at the end of a 540 // node with the last group's children address still less than what we are searching for, we 541 // must descend the last group's children (for example, if the word we are searching for starts 542 // with a z, it's the last group of the root node, so all children addresses will be smaller 543 // than the address we look for, and we have to descend the z node). 544 /* Parameters : 545 * root: the dictionary buffer 546 * address: the byte position of the last chargroup of the word we are searching for (this is 547 * what is stored as the "bigram address" in each bigram) 548 * outword: an array to write the found word, with MAX_WORD_LENGTH size. 549 * outUnigramProbability: a pointer to an int to write the probability into. 550 * Return value : the length of the word, of 0 if the word was not found. 551 */ 552 AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, const int address, 553 const int maxDepth, int *outWord, int *outUnigramProbability) { 554 int pos = 0; 555 int wordPos = 0; 556 557 // One iteration of the outer loop iterates through nodes. As stated above, we will only 558 // traverse nodes that are actually a part of the terminal we are searching, so each time 559 // we enter this loop we are one depth level further than last time. 560 // The only reason we count nodes is because we want to reduce the probability of infinite 561 // looping in case there is a bug. Since we know there is an upper bound to the depth we are 562 // supposed to traverse, it does not hurt to count iterations. 563 for (int loopCount = maxDepth; loopCount > 0; --loopCount) { 564 int lastCandidateGroupPos = 0; 565 // Let's loop through char groups in this node searching for either the terminal 566 // or one of its ascendants. 567 for (int charGroupCount = getGroupCountAndForwardPointer(root, &pos); charGroupCount > 0; 568 --charGroupCount) { 569 const int startPos = pos; 570 const uint8_t flags = getFlagsAndForwardPointer(root, &pos); 571 const int character = getCodePointAndForwardPointer(root, &pos); 572 if (address == startPos) { 573 // We found the address. Copy the rest of the word in the buffer and return 574 // the length. 575 outWord[wordPos] = character; 576 if (FLAG_HAS_MULTIPLE_CHARS & flags) { 577 int nextChar = getCodePointAndForwardPointer(root, &pos); 578 // We count chars in order to avoid infinite loops if the file is broken or 579 // if there is some other bug 580 int charCount = maxDepth; 581 while (NOT_A_CODE_POINT != nextChar && --charCount > 0) { 582 outWord[++wordPos] = nextChar; 583 nextChar = getCodePointAndForwardPointer(root, &pos); 584 } 585 } 586 *outUnigramProbability = readProbabilityWithoutMovingPointer(root, pos); 587 return ++wordPos; 588 } 589 // We need to skip past this char group, so skip any remaining chars after the 590 // first and possibly the probability. 591 if (FLAG_HAS_MULTIPLE_CHARS & flags) { 592 pos = skipOtherCharacters(root, pos); 593 } 594 pos = skipProbability(flags, pos); 595 596 // The fact that this group has children is very important. Since we already know 597 // that this group does not match, if it has no children we know it is irrelevant 598 // to what we are searching for. 599 const bool hasChildren = (FLAG_GROUP_ADDRESS_TYPE_NOADDRESS != 600 (MASK_GROUP_ADDRESS_TYPE & flags)); 601 // We will write in `found' whether we have passed the children address we are 602 // searching for. For example if we search for "beer", the children of b are less 603 // than the address we are searching for and the children of c are greater. When we 604 // come here for c, we realize this is too big, and that we should descend b. 605 bool found; 606 if (hasChildren) { 607 // Here comes the tricky part. First, read the children position. 608 const int childrenPos = readChildrenPosition(root, flags, pos); 609 if (childrenPos > address) { 610 // If the children pos is greater than address, it means the previous chargroup, 611 // which address is stored in lastCandidateGroupPos, was the right one. 612 found = true; 613 } else if (1 >= charGroupCount) { 614 // However if we are on the LAST group of this node, and we have NOT shot the 615 // address we should descend THIS node. So we trick the lastCandidateGroupPos 616 // so that we will descend this node, not the previous one. 617 lastCandidateGroupPos = startPos; 618 found = true; 619 } else { 620 // Else, we should continue looking. 621 found = false; 622 } 623 } else { 624 // Even if we don't have children here, we could still be on the last group of this 625 // node. If this is the case, we should descend the last group that had children, 626 // and their address is already in lastCandidateGroup. 627 found = (1 >= charGroupCount); 628 } 629 630 if (found) { 631 // Okay, we found the group we should descend. Its address is in 632 // the lastCandidateGroupPos variable, so we just re-read it. 633 if (0 != lastCandidateGroupPos) { 634 const uint8_t lastFlags = 635 getFlagsAndForwardPointer(root, &lastCandidateGroupPos); 636 const int lastChar = 637 getCodePointAndForwardPointer(root, &lastCandidateGroupPos); 638 // We copy all the characters in this group to the buffer 639 outWord[wordPos] = lastChar; 640 if (FLAG_HAS_MULTIPLE_CHARS & lastFlags) { 641 int nextChar = getCodePointAndForwardPointer(root, &lastCandidateGroupPos); 642 int charCount = maxDepth; 643 while (-1 != nextChar && --charCount > 0) { 644 outWord[++wordPos] = nextChar; 645 nextChar = getCodePointAndForwardPointer(root, &lastCandidateGroupPos); 646 } 647 } 648 ++wordPos; 649 // Now we only need to branch to the children address. Skip the probability if 650 // it's there, read pos, and break to resume the search at pos. 651 lastCandidateGroupPos = skipProbability(lastFlags, lastCandidateGroupPos); 652 pos = readChildrenPosition(root, lastFlags, lastCandidateGroupPos); 653 break; 654 } else { 655 // Here is a little tricky part: we come here if we found out that all children 656 // addresses in this group are bigger than the address we are searching for. 657 // Should we conclude the word is not in the dictionary? No! It could still be 658 // one of the remaining chargroups in this node, so we have to keep looking in 659 // this node until we find it (or we realize it's not there either, in which 660 // case it's actually not in the dictionary). Pass the end of this group, ready 661 // to start the next one. 662 pos = skipChildrenPosAndAttributes(root, flags, pos); 663 } 664 } else { 665 // If we did not find it, we should record the last children address for the next 666 // iteration. 667 if (hasChildren) lastCandidateGroupPos = startPos; 668 // Now skip the end of this group (children pos and the attributes if any) so that 669 // our pos is after the end of this char group, at the start of the next one. 670 pos = skipChildrenPosAndAttributes(root, flags, pos); 671 } 672 673 } 674 } 675 // If we have looked through all the chargroups and found no match, the address is 676 // not the address of a terminal in this dictionary. 677 return 0; 678 } 679 680 static inline int backoff(const int unigramProbability) { 681 return unigramProbability; 682 // For some reason, applying the backoff weight gives bad results in tests. To apply the 683 // backoff weight, we divide the probability by 2, which in our storing format means 684 // decreasing the score by 8. 685 // TODO: figure out what's wrong with this. 686 // return unigramProbability > 8 ? unigramProbability - 8 : (0 == unigramProbability ? 0 : 8); 687 } 688 689 inline int BinaryFormat::computeProbabilityForBigram( 690 const int unigramProbability, const int bigramProbability) { 691 // We divide the range [unigramProbability..255] in 16.5 steps - in other words, we want the 692 // unigram probability to be the median value of the 17th step from the top. A value of 693 // 0 for the bigram probability represents the middle of the 16th step from the top, 694 // while a value of 15 represents the middle of the top step. 695 // See makedict.BinaryDictInputOutput for details. 696 const float stepSize = static_cast<float>(MAX_PROBABILITY - unigramProbability) 697 / (1.5f + MAX_BIGRAM_ENCODED_PROBABILITY); 698 return unigramProbability 699 + static_cast<int>(static_cast<float>(bigramProbability + 1) * stepSize); 700 } 701 702 // This returns a probability in log space. 703 inline int BinaryFormat::getProbability(const int position, const std::map<int, int> *bigramMap, 704 const uint8_t *bigramFilter, const int unigramProbability) { 705 if (!bigramMap || !bigramFilter) return backoff(unigramProbability); 706 if (!isInFilter(bigramFilter, position)) return backoff(unigramProbability); 707 const std::map<int, int>::const_iterator bigramProbabilityIt = bigramMap->find(position); 708 if (bigramProbabilityIt != bigramMap->end()) { 709 const int bigramProbability = bigramProbabilityIt->second; 710 return computeProbabilityForBigram(unigramProbability, bigramProbability); 711 } 712 return backoff(unigramProbability); 713 } 714 715 // This returns a probability in log space. 716 inline int BinaryFormat::getBigramProbabilityFromHashMap(const int position, 717 const hash_map_compat<int, int> *bigramMap, const int unigramProbability) { 718 if (!bigramMap) return backoff(unigramProbability); 719 const hash_map_compat<int, int>::const_iterator bigramProbabilityIt = bigramMap->find(position); 720 if (bigramProbabilityIt != bigramMap->end()) { 721 const int bigramProbability = bigramProbabilityIt->second; 722 return computeProbabilityForBigram(unigramProbability, bigramProbability); 723 } 724 return backoff(unigramProbability); 725 } 726 727 AK_FORCE_INLINE void BinaryFormat::fillBigramProbabilityToHashMap( 728 const uint8_t *const root, int position, hash_map_compat<int, int> *bigramMap) { 729 position = getBigramListPositionForWordPosition(root, position); 730 if (0 == position) return; 731 732 uint8_t bigramFlags; 733 do { 734 bigramFlags = getFlagsAndForwardPointer(root, &position); 735 const int probability = MASK_ATTRIBUTE_PROBABILITY & bigramFlags; 736 const int bigramPos = getAttributeAddressAndForwardPointer(root, bigramFlags, 737 &position); 738 (*bigramMap)[bigramPos] = probability; 739 } while (FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags); 740 } 741 742 AK_FORCE_INLINE int BinaryFormat::getBigramProbability(const uint8_t *const root, int position, 743 const int nextPosition, const int unigramProbability) { 744 position = getBigramListPositionForWordPosition(root, position); 745 if (0 == position) return backoff(unigramProbability); 746 747 uint8_t bigramFlags; 748 do { 749 bigramFlags = getFlagsAndForwardPointer(root, &position); 750 const int bigramPos = getAttributeAddressAndForwardPointer( 751 root, bigramFlags, &position); 752 if (bigramPos == nextPosition) { 753 const int bigramProbability = MASK_ATTRIBUTE_PROBABILITY & bigramFlags; 754 return computeProbabilityForBigram(unigramProbability, bigramProbability); 755 } 756 } while (FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags); 757 return backoff(unigramProbability); 758 } 759 760 // Returns a pointer to the start of the bigram list. 761 AK_FORCE_INLINE int BinaryFormat::getBigramListPositionForWordPosition( 762 const uint8_t *const root, int position) { 763 if (NOT_VALID_WORD == position) return 0; 764 const uint8_t flags = getFlagsAndForwardPointer(root, &position); 765 if (!(flags & FLAG_HAS_BIGRAMS)) return 0; 766 if (flags & FLAG_HAS_MULTIPLE_CHARS) { 767 position = skipOtherCharacters(root, position); 768 } else { 769 getCodePointAndForwardPointer(root, &position); 770 } 771 position = skipProbability(flags, position); 772 position = skipChildrenPosition(flags, position); 773 position = skipShortcuts(root, flags, position); 774 return position; 775 } 776 777 } // namespace latinime 778 #endif // LATINIME_BINARY_FORMAT_H 779