Home | History | Annotate | Download | only in src
      1 /*
      2  * Copyright (C) 2011 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef LATINIME_BINARY_FORMAT_H
     18 #define LATINIME_BINARY_FORMAT_H
     19 
     20 #include <cstdlib>
     21 #include <map>
     22 #include <stdint.h>
     23 
     24 #include "bloom_filter.h"
     25 #include "char_utils.h"
     26 #include "hash_map_compat.h"
     27 
     28 namespace latinime {
     29 
     30 class BinaryFormat {
     31  public:
     32     // Mask and flags for children address type selection.
     33     static const int MASK_GROUP_ADDRESS_TYPE = 0xC0;
     34 
     35     // Flag for single/multiple char group
     36     static const int FLAG_HAS_MULTIPLE_CHARS = 0x20;
     37 
     38     // Flag for terminal groups
     39     static const int FLAG_IS_TERMINAL = 0x10;
     40 
     41     // Flag for shortcut targets presence
     42     static const int FLAG_HAS_SHORTCUT_TARGETS = 0x08;
     43     // Flag for bigram presence
     44     static const int FLAG_HAS_BIGRAMS = 0x04;
     45     // Flag for non-words (typically, shortcut only entries)
     46     static const int FLAG_IS_NOT_A_WORD = 0x02;
     47     // Flag for blacklist
     48     static const int FLAG_IS_BLACKLISTED = 0x01;
     49 
     50     // Attribute (bigram/shortcut) related flags:
     51     // Flag for presence of more attributes
     52     static const int FLAG_ATTRIBUTE_HAS_NEXT = 0x80;
     53     // Flag for sign of offset. If this flag is set, the offset value must be negated.
     54     static const int FLAG_ATTRIBUTE_OFFSET_NEGATIVE = 0x40;
     55 
     56     // Mask for attribute probability, stored on 4 bits inside the flags byte.
     57     static const int MASK_ATTRIBUTE_PROBABILITY = 0x0F;
     58     // The numeric value of the shortcut probability that means 'whitelist'.
     59     static const int WHITELIST_SHORTCUT_PROBABILITY = 15;
     60 
     61     // Mask and flags for attribute address type selection.
     62     static const int MASK_ATTRIBUTE_ADDRESS_TYPE = 0x30;
     63 
     64     static const int UNKNOWN_FORMAT = -1;
     65     static const int SHORTCUT_LIST_SIZE_SIZE = 2;
     66 
     67     static int detectFormat(const uint8_t *const dict, const int dictSize);
     68     static int getHeaderSize(const uint8_t *const dict, const int dictSize);
     69     static int getFlags(const uint8_t *const dict, const int dictSize);
     70     static bool hasBlacklistedOrNotAWordFlag(const int flags);
     71     static void readHeaderValue(const uint8_t *const dict, const int dictSize,
     72             const char *const key, int *outValue, const int outValueSize);
     73     static int readHeaderValueInt(const uint8_t *const dict, const int dictSize,
     74             const char *const key);
     75     static int getGroupCountAndForwardPointer(const uint8_t *const dict, int *pos);
     76     static uint8_t getFlagsAndForwardPointer(const uint8_t *const dict, int *pos);
     77     static int getCodePointAndForwardPointer(const uint8_t *const dict, int *pos);
     78     static int readProbabilityWithoutMovingPointer(const uint8_t *const dict, const int pos);
     79     static int skipOtherCharacters(const uint8_t *const dict, const int pos);
     80     static int skipChildrenPosition(const uint8_t flags, const int pos);
     81     static int skipProbability(const uint8_t flags, const int pos);
     82     static int skipShortcuts(const uint8_t *const dict, const uint8_t flags, const int pos);
     83     static int skipChildrenPosAndAttributes(const uint8_t *const dict, const uint8_t flags,
     84             const int pos);
     85     static int readChildrenPosition(const uint8_t *const dict, const uint8_t flags, const int pos);
     86     static bool hasChildrenInFlags(const uint8_t flags);
     87     static int getAttributeAddressAndForwardPointer(const uint8_t *const dict, const uint8_t flags,
     88             int *pos);
     89     static int getAttributeProbabilityFromFlags(const int flags);
     90     static int getTerminalPosition(const uint8_t *const root, const int *const inWord,
     91             const int length, const bool forceLowerCaseSearch);
     92     static int getWordAtAddress(const uint8_t *const root, const int address, const int maxDepth,
     93             int *outWord, int *outUnigramProbability);
     94     static int computeProbabilityForBigram(
     95             const int unigramProbability, const int bigramProbability);
     96     static int getProbability(const int position, const std::map<int, int> *bigramMap,
     97             const uint8_t *bigramFilter, const int unigramProbability);
     98     static int getBigramProbabilityFromHashMap(const int position,
     99             const hash_map_compat<int, int> *bigramMap, const int unigramProbability);
    100     static float getMultiWordCostMultiplier(const uint8_t *const dict, const int dictSize);
    101     static void fillBigramProbabilityToHashMap(const uint8_t *const root, int position,
    102             hash_map_compat<int, int> *bigramMap);
    103     static int getBigramProbability(const uint8_t *const root, int position,
    104             const int nextPosition, const int unigramProbability);
    105 
    106     // Flags for special processing
    107     // Those *must* match the flags in makedict (BinaryDictInputOutput#*_PROCESSING_FLAG) or
    108     // something very bad (like, the apocalypse) will happen. Please update both at the same time.
    109     enum {
    110         REQUIRES_GERMAN_UMLAUT_PROCESSING = 0x1,
    111         REQUIRES_FRENCH_LIGATURES_PROCESSING = 0x4
    112     };
    113 
    114  private:
    115     DISALLOW_IMPLICIT_CONSTRUCTORS(BinaryFormat);
    116     static int getBigramListPositionForWordPosition(const uint8_t *const root, int position);
    117 
    118     static const int FLAG_GROUP_ADDRESS_TYPE_NOADDRESS = 0x00;
    119     static const int FLAG_GROUP_ADDRESS_TYPE_ONEBYTE = 0x40;
    120     static const int FLAG_GROUP_ADDRESS_TYPE_TWOBYTES = 0x80;
    121     static const int FLAG_GROUP_ADDRESS_TYPE_THREEBYTES = 0xC0;
    122     static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE = 0x10;
    123     static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES = 0x20;
    124     static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES = 0x30;
    125 
    126     // Any file smaller than this is not a dictionary.
    127     static const int DICTIONARY_MINIMUM_SIZE = 4;
    128     // Originally, format version 1 had a 16-bit magic number, then the version number `01'
    129     // then options that must be 0. Hence the first 32-bits of the format are always as follow
    130     // and it's okay to consider them a magic number as a whole.
    131     static const int FORMAT_VERSION_1_MAGIC_NUMBER = 0x78B10100;
    132     static const int FORMAT_VERSION_1_HEADER_SIZE = 5;
    133     // The versions of Latin IME that only handle format version 1 only test for the magic
    134     // number, so we had to change it so that version 2 files would be rejected by older
    135     // implementations. On this occasion, we made the magic number 32 bits long.
    136     static const int FORMAT_VERSION_2_MAGIC_NUMBER = -1681835266; // 0x9BC13AFE
    137     // Magic number (4 bytes), version (2 bytes), options (2 bytes), header size (4 bytes) = 12
    138     static const int FORMAT_VERSION_2_MINIMUM_SIZE = 12;
    139 
    140     static const int CHARACTER_ARRAY_TERMINATOR_SIZE = 1;
    141     static const int MINIMAL_ONE_BYTE_CHARACTER_VALUE = 0x20;
    142     static const int CHARACTER_ARRAY_TERMINATOR = 0x1F;
    143     static const int MULTIPLE_BYTE_CHARACTER_ADDITIONAL_SIZE = 2;
    144     static const int NO_FLAGS = 0;
    145     static int skipAllAttributes(const uint8_t *const dict, const uint8_t flags, const int pos);
    146     static int skipBigrams(const uint8_t *const dict, const uint8_t flags, const int pos);
    147 };
    148 
    149 AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict, const int dictSize) {
    150     // The magic number is stored big-endian.
    151     // If the dictionary is less than 4 bytes, we can't even read the magic number, so we don't
    152     // understand this format.
    153     if (dictSize < DICTIONARY_MINIMUM_SIZE) return UNKNOWN_FORMAT;
    154     const int magicNumber = (dict[0] << 24) + (dict[1] << 16) + (dict[2] << 8) + dict[3];
    155     switch (magicNumber) {
    156     case FORMAT_VERSION_1_MAGIC_NUMBER:
    157         // Format 1 header is exactly 5 bytes long and looks like:
    158         // Magic number (2 bytes) 0x78 0xB1
    159         // Version number (1 byte) 0x01
    160         // Options (2 bytes) must be 0x00 0x00
    161         return 1;
    162     case FORMAT_VERSION_2_MAGIC_NUMBER:
    163         // Version 2 dictionaries are at least 12 bytes long (see below details for the header).
    164         // If this dictionary has the version 2 magic number but is less than 12 bytes long, then
    165         // it's an unknown format and we need to avoid confidently reading the next bytes.
    166         if (dictSize < FORMAT_VERSION_2_MINIMUM_SIZE) return UNKNOWN_FORMAT;
    167         // Format 2 header is as follows:
    168         // Magic number (4 bytes) 0x9B 0xC1 0x3A 0xFE
    169         // Version number (2 bytes) 0x00 0x02
    170         // Options (2 bytes)
    171         // Header size (4 bytes) : integer, big endian
    172         return (dict[4] << 8) + dict[5];
    173     default:
    174         return UNKNOWN_FORMAT;
    175     }
    176 }
    177 
    178 inline int BinaryFormat::getFlags(const uint8_t *const dict, const int dictSize) {
    179     switch (detectFormat(dict, dictSize)) {
    180     case 1:
    181         return NO_FLAGS; // TODO: NO_FLAGS is unused anywhere else?
    182     default:
    183         return (dict[6] << 8) + dict[7];
    184     }
    185 }
    186 
    187 inline bool BinaryFormat::hasBlacklistedOrNotAWordFlag(const int flags) {
    188     return (flags & (FLAG_IS_BLACKLISTED | FLAG_IS_NOT_A_WORD)) != 0;
    189 }
    190 
    191 inline int BinaryFormat::getHeaderSize(const uint8_t *const dict, const int dictSize) {
    192     switch (detectFormat(dict, dictSize)) {
    193     case 1:
    194         return FORMAT_VERSION_1_HEADER_SIZE;
    195     case 2:
    196         // See the format of the header in the comment in detectFormat() above
    197         return (dict[8] << 24) + (dict[9] << 16) + (dict[10] << 8) + dict[11];
    198     default:
    199         return S_INT_MAX;
    200     }
    201 }
    202 
    203 inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const int dictSize,
    204         const char *const key, int *outValue, const int outValueSize) {
    205     int outValueIndex = 0;
    206     // Only format 2 and above have header attributes as {key,value} string pairs. For prior
    207     // formats, we just return an empty string, as if the key wasn't found.
    208     if (2 <= detectFormat(dict, dictSize)) {
    209         const int headerOptionsOffset = 4 /* magic number */
    210                 + 2 /* dictionary version */ + 2 /* flags */;
    211         const int headerSize =
    212                 (dict[headerOptionsOffset] << 24) + (dict[headerOptionsOffset + 1] << 16)
    213                 + (dict[headerOptionsOffset + 2] << 8) + dict[headerOptionsOffset + 3];
    214         const int headerEnd = headerOptionsOffset + 4 + headerSize;
    215         int index = headerOptionsOffset + 4;
    216         while (index < headerEnd) {
    217             int keyIndex = 0;
    218             int codePoint = getCodePointAndForwardPointer(dict, &index);
    219             while (codePoint != NOT_A_CODE_POINT) {
    220                 if (codePoint != key[keyIndex++]) {
    221                     break;
    222                 }
    223                 codePoint = getCodePointAndForwardPointer(dict, &index);
    224             }
    225             if (codePoint == NOT_A_CODE_POINT && key[keyIndex] == 0) {
    226                 // We found the key! Copy and return the value.
    227                 codePoint = getCodePointAndForwardPointer(dict, &index);
    228                 while (codePoint != NOT_A_CODE_POINT && outValueIndex < outValueSize) {
    229                     outValue[outValueIndex++] = codePoint;
    230                     codePoint = getCodePointAndForwardPointer(dict, &index);
    231                 }
    232                 // Finished copying. Break to go to the termination code.
    233                 break;
    234             }
    235             // We didn't find the key, skip the remainder of it and its value
    236             while (codePoint != NOT_A_CODE_POINT) {
    237                 codePoint = getCodePointAndForwardPointer(dict, &index);
    238             }
    239             codePoint = getCodePointAndForwardPointer(dict, &index);
    240             while (codePoint != NOT_A_CODE_POINT) {
    241                 codePoint = getCodePointAndForwardPointer(dict, &index);
    242             }
    243         }
    244         // We couldn't find it - fall through and return an empty value.
    245     }
    246     // Put a terminator 0 if possible at all (always unless outValueSize is <= 0)
    247     if (outValueIndex >= outValueSize) outValueIndex = outValueSize - 1;
    248     if (outValueIndex >= 0) outValue[outValueIndex] = 0;
    249 }
    250 
    251 inline int BinaryFormat::readHeaderValueInt(const uint8_t *const dict, const int dictSize,
    252         const char *const key) {
    253     const int bufferSize = LARGEST_INT_DIGIT_COUNT;
    254     int intBuffer[bufferSize];
    255     char charBuffer[bufferSize];
    256     BinaryFormat::readHeaderValue(dict, dictSize, key, intBuffer, bufferSize);
    257     for (int i = 0; i < bufferSize; ++i) {
    258         charBuffer[i] = intBuffer[i];
    259     }
    260     // If not a number, return S_INT_MIN
    261     if (!isdigit(charBuffer[0])) return S_INT_MIN;
    262     return atoi(charBuffer);
    263 }
    264 
    265 AK_FORCE_INLINE int BinaryFormat::getGroupCountAndForwardPointer(const uint8_t *const dict,
    266         int *pos) {
    267     const int msb = dict[(*pos)++];
    268     if (msb < 0x80) return msb;
    269     return ((msb & 0x7F) << 8) | dict[(*pos)++];
    270 }
    271 
    272 inline float BinaryFormat::getMultiWordCostMultiplier(const uint8_t *const dict,
    273         const int dictSize) {
    274     const int headerValue = readHeaderValueInt(dict, dictSize,
    275             "MULTIPLE_WORDS_DEMOTION_RATE");
    276     if (headerValue == S_INT_MIN) {
    277         return 1.0f;
    278     }
    279     if (headerValue <= 0) {
    280         return static_cast<float>(MAX_VALUE_FOR_WEIGHTING);
    281     }
    282     return 100.0f / static_cast<float>(headerValue);
    283 }
    284 
    285 inline uint8_t BinaryFormat::getFlagsAndForwardPointer(const uint8_t *const dict, int *pos) {
    286     return dict[(*pos)++];
    287 }
    288 
    289 AK_FORCE_INLINE int BinaryFormat::getCodePointAndForwardPointer(const uint8_t *const dict,
    290         int *pos) {
    291     const int origin = *pos;
    292     const int codePoint = dict[origin];
    293     if (codePoint < MINIMAL_ONE_BYTE_CHARACTER_VALUE) {
    294         if (codePoint == CHARACTER_ARRAY_TERMINATOR) {
    295             *pos = origin + 1;
    296             return NOT_A_CODE_POINT;
    297         } else {
    298             *pos = origin + 3;
    299             const int char_1 = codePoint << 16;
    300             const int char_2 = char_1 + (dict[origin + 1] << 8);
    301             return char_2 + dict[origin + 2];
    302         }
    303     } else {
    304         *pos = origin + 1;
    305         return codePoint;
    306     }
    307 }
    308 
    309 inline int BinaryFormat::readProbabilityWithoutMovingPointer(const uint8_t *const dict,
    310         const int pos) {
    311     return dict[pos];
    312 }
    313 
    314 AK_FORCE_INLINE int BinaryFormat::skipOtherCharacters(const uint8_t *const dict, const int pos) {
    315     int currentPos = pos;
    316     int character = dict[currentPos++];
    317     while (CHARACTER_ARRAY_TERMINATOR != character) {
    318         if (character < MINIMAL_ONE_BYTE_CHARACTER_VALUE) {
    319             currentPos += MULTIPLE_BYTE_CHARACTER_ADDITIONAL_SIZE;
    320         }
    321         character = dict[currentPos++];
    322     }
    323     return currentPos;
    324 }
    325 
    326 static inline int attributeAddressSize(const uint8_t flags) {
    327     static const int ATTRIBUTE_ADDRESS_SHIFT = 4;
    328     return (flags & BinaryFormat::MASK_ATTRIBUTE_ADDRESS_TYPE) >> ATTRIBUTE_ADDRESS_SHIFT;
    329     /* Note: this is a value-dependant optimization of what may probably be
    330        more readably written this way:
    331        switch (flags * BinaryFormat::MASK_ATTRIBUTE_ADDRESS_TYPE) {
    332        case FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE: return 1;
    333        case FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES: return 2;
    334        case FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTE: return 3;
    335        default: return 0;
    336        }
    337     */
    338 }
    339 
    340 static AK_FORCE_INLINE int skipExistingBigrams(const uint8_t *const dict, const int pos) {
    341     int currentPos = pos;
    342     uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(dict, &currentPos);
    343     while (flags & BinaryFormat::FLAG_ATTRIBUTE_HAS_NEXT) {
    344         currentPos += attributeAddressSize(flags);
    345         flags = BinaryFormat::getFlagsAndForwardPointer(dict, &currentPos);
    346     }
    347     currentPos += attributeAddressSize(flags);
    348     return currentPos;
    349 }
    350 
    351 static inline int childrenAddressSize(const uint8_t flags) {
    352     static const int CHILDREN_ADDRESS_SHIFT = 6;
    353     return (BinaryFormat::MASK_GROUP_ADDRESS_TYPE & flags) >> CHILDREN_ADDRESS_SHIFT;
    354     /* See the note in attributeAddressSize. The same applies here */
    355 }
    356 
    357 static AK_FORCE_INLINE int shortcutByteSize(const uint8_t *const dict, const int pos) {
    358     return (static_cast<int>(dict[pos] << 8)) + (dict[pos + 1]);
    359 }
    360 
    361 inline int BinaryFormat::skipChildrenPosition(const uint8_t flags, const int pos) {
    362     return pos + childrenAddressSize(flags);
    363 }
    364 
    365 inline int BinaryFormat::skipProbability(const uint8_t flags, const int pos) {
    366     return FLAG_IS_TERMINAL & flags ? pos + 1 : pos;
    367 }
    368 
    369 AK_FORCE_INLINE int BinaryFormat::skipShortcuts(const uint8_t *const dict, const uint8_t flags,
    370         const int pos) {
    371     if (FLAG_HAS_SHORTCUT_TARGETS & flags) {
    372         return pos + shortcutByteSize(dict, pos);
    373     } else {
    374         return pos;
    375     }
    376 }
    377 
    378 AK_FORCE_INLINE int BinaryFormat::skipBigrams(const uint8_t *const dict, const uint8_t flags,
    379         const int pos) {
    380     if (FLAG_HAS_BIGRAMS & flags) {
    381         return skipExistingBigrams(dict, pos);
    382     } else {
    383         return pos;
    384     }
    385 }
    386 
    387 AK_FORCE_INLINE int BinaryFormat::skipAllAttributes(const uint8_t *const dict, const uint8_t flags,
    388         const int pos) {
    389     // This function skips all attributes: shortcuts and bigrams.
    390     int newPos = pos;
    391     newPos = skipShortcuts(dict, flags, newPos);
    392     newPos = skipBigrams(dict, flags, newPos);
    393     return newPos;
    394 }
    395 
    396 AK_FORCE_INLINE int BinaryFormat::skipChildrenPosAndAttributes(const uint8_t *const dict,
    397         const uint8_t flags, const int pos) {
    398     int currentPos = pos;
    399     currentPos = skipChildrenPosition(flags, currentPos);
    400     currentPos = skipAllAttributes(dict, flags, currentPos);
    401     return currentPos;
    402 }
    403 
    404 AK_FORCE_INLINE int BinaryFormat::readChildrenPosition(const uint8_t *const dict,
    405         const uint8_t flags, const int pos) {
    406     int offset = 0;
    407     switch (MASK_GROUP_ADDRESS_TYPE & flags) {
    408         case FLAG_GROUP_ADDRESS_TYPE_ONEBYTE:
    409             offset = dict[pos];
    410             break;
    411         case FLAG_GROUP_ADDRESS_TYPE_TWOBYTES:
    412             offset = dict[pos] << 8;
    413             offset += dict[pos + 1];
    414             break;
    415         case FLAG_GROUP_ADDRESS_TYPE_THREEBYTES:
    416             offset = dict[pos] << 16;
    417             offset += dict[pos + 1] << 8;
    418             offset += dict[pos + 2];
    419             break;
    420         default:
    421             // If we come here, it means we asked for the children of a word with
    422             // no children.
    423             return -1;
    424     }
    425     return pos + offset;
    426 }
    427 
    428 inline bool BinaryFormat::hasChildrenInFlags(const uint8_t flags) {
    429     return (FLAG_GROUP_ADDRESS_TYPE_NOADDRESS != (MASK_GROUP_ADDRESS_TYPE & flags));
    430 }
    431 
    432 AK_FORCE_INLINE int BinaryFormat::getAttributeAddressAndForwardPointer(const uint8_t *const dict,
    433         const uint8_t flags, int *pos) {
    434     int offset = 0;
    435     const int origin = *pos;
    436     switch (MASK_ATTRIBUTE_ADDRESS_TYPE & flags) {
    437         case FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE:
    438             offset = dict[origin];
    439             *pos = origin + 1;
    440             break;
    441         case FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES:
    442             offset = dict[origin] << 8;
    443             offset += dict[origin + 1];
    444             *pos = origin + 2;
    445             break;
    446         case FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES:
    447             offset = dict[origin] << 16;
    448             offset += dict[origin + 1] << 8;
    449             offset += dict[origin + 2];
    450             *pos = origin + 3;
    451             break;
    452     }
    453     if (FLAG_ATTRIBUTE_OFFSET_NEGATIVE & flags) {
    454         return origin - offset;
    455     } else {
    456         return origin + offset;
    457     }
    458 }
    459 
    460 inline int BinaryFormat::getAttributeProbabilityFromFlags(const int flags) {
    461     return flags & MASK_ATTRIBUTE_PROBABILITY;
    462 }
    463 
    464 // This function gets the byte position of the last chargroup of the exact matching word in the
    465 // dictionary. If no match is found, it returns NOT_VALID_WORD.
    466 AK_FORCE_INLINE int BinaryFormat::getTerminalPosition(const uint8_t *const root,
    467         const int *const inWord, const int length, const bool forceLowerCaseSearch) {
    468     int pos = 0;
    469     int wordPos = 0;
    470 
    471     while (true) {
    472         // If we already traversed the tree further than the word is long, there means
    473         // there was no match (or we would have found it).
    474         if (wordPos >= length) return NOT_VALID_WORD;
    475         int charGroupCount = BinaryFormat::getGroupCountAndForwardPointer(root, &pos);
    476         const int wChar = forceLowerCaseSearch ? toLowerCase(inWord[wordPos]) : inWord[wordPos];
    477         while (true) {
    478             // If there are no more character groups in this node, it means we could not
    479             // find a matching character for this depth, therefore there is no match.
    480             if (0 >= charGroupCount) return NOT_VALID_WORD;
    481             const int charGroupPos = pos;
    482             const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(root, &pos);
    483             int character = BinaryFormat::getCodePointAndForwardPointer(root, &pos);
    484             if (character == wChar) {
    485                 // This is the correct node. Only one character group may start with the same
    486                 // char within a node, so either we found our match in this node, or there is
    487                 // no match and we can return NOT_VALID_WORD. So we will check all the characters
    488                 // in this character group indeed does match.
    489                 if (FLAG_HAS_MULTIPLE_CHARS & flags) {
    490                     character = BinaryFormat::getCodePointAndForwardPointer(root, &pos);
    491                     while (NOT_A_CODE_POINT != character) {
    492                         ++wordPos;
    493                         // If we shoot the length of the word we search for, or if we find a single
    494                         // character that does not match, as explained above, it means the word is
    495                         // not in the dictionary (by virtue of this chargroup being the only one to
    496                         // match the word on the first character, but not matching the whole word).
    497                         if (wordPos >= length) return NOT_VALID_WORD;
    498                         if (inWord[wordPos] != character) return NOT_VALID_WORD;
    499                         character = BinaryFormat::getCodePointAndForwardPointer(root, &pos);
    500                     }
    501                 }
    502                 // If we come here we know that so far, we do match. Either we are on a terminal
    503                 // and we match the length, in which case we found it, or we traverse children.
    504                 // If we don't match the length AND don't have children, then a word in the
    505                 // dictionary fully matches a prefix of the searched word but not the full word.
    506                 ++wordPos;
    507                 if (FLAG_IS_TERMINAL & flags) {
    508                     if (wordPos == length) {
    509                         return charGroupPos;
    510                     }
    511                     pos = BinaryFormat::skipProbability(FLAG_IS_TERMINAL, pos);
    512                 }
    513                 if (FLAG_GROUP_ADDRESS_TYPE_NOADDRESS == (MASK_GROUP_ADDRESS_TYPE & flags)) {
    514                     return NOT_VALID_WORD;
    515                 }
    516                 // We have children and we are still shorter than the word we are searching for, so
    517                 // we need to traverse children. Put the pointer on the children position, and
    518                 // break
    519                 pos = BinaryFormat::readChildrenPosition(root, flags, pos);
    520                 break;
    521             } else {
    522                 // This chargroup does not match, so skip the remaining part and go to the next.
    523                 if (FLAG_HAS_MULTIPLE_CHARS & flags) {
    524                     pos = BinaryFormat::skipOtherCharacters(root, pos);
    525                 }
    526                 pos = BinaryFormat::skipProbability(flags, pos);
    527                 pos = BinaryFormat::skipChildrenPosAndAttributes(root, flags, pos);
    528             }
    529             --charGroupCount;
    530         }
    531     }
    532 }
    533 
    534 // This function searches for a terminal in the dictionary by its address.
    535 // Due to the fact that words are ordered in the dictionary in a strict breadth-first order,
    536 // it is possible to check for this with advantageous complexity. For each node, we search
    537 // for groups with children and compare the children address with the address we look for.
    538 // When we shoot the address we look for, it means the word we look for is in the children
    539 // of the previous group. The only tricky part is the fact that if we arrive at the end of a
    540 // node with the last group's children address still less than what we are searching for, we
    541 // must descend the last group's children (for example, if the word we are searching for starts
    542 // with a z, it's the last group of the root node, so all children addresses will be smaller
    543 // than the address we look for, and we have to descend the z node).
    544 /* Parameters :
    545  * root: the dictionary buffer
    546  * address: the byte position of the last chargroup of the word we are searching for (this is
    547  *   what is stored as the "bigram address" in each bigram)
    548  * outWord: an array to write the found word, with MAX_WORD_LENGTH size.
    549  * outUnigramProbability: a pointer to an int to write the probability into.
    550  * Return value : the length of the word, or 0 if the word was not found.
    551  */
    552 AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, const int address,
    553         const int maxDepth, int *outWord, int *outUnigramProbability) {
    554     int pos = 0;
    555     int wordPos = 0;
    556 
    557     // One iteration of the outer loop iterates through nodes. As stated above, we will only
    558     // traverse nodes that are actually a part of the terminal we are searching, so each time
    559     // we enter this loop we are one depth level further than last time.
    560     // The only reason we count nodes is because we want to reduce the probability of infinite
    561     // looping in case there is a bug. Since we know there is an upper bound to the depth we are
    562     // supposed to traverse, it does not hurt to count iterations.
    563     for (int loopCount = maxDepth; loopCount > 0; --loopCount) {
    564         int lastCandidateGroupPos = 0;
    565         // Let's loop through char groups in this node searching for either the terminal
    566         // or one of its ascendants.
    567         for (int charGroupCount = getGroupCountAndForwardPointer(root, &pos); charGroupCount > 0;
    568                  --charGroupCount) {
    569             const int startPos = pos;
    570             const uint8_t flags = getFlagsAndForwardPointer(root, &pos);
    571             const int character = getCodePointAndForwardPointer(root, &pos);
    572             if (address == startPos) {
    573                 // We found the address. Copy the rest of the word in the buffer and return
    574                 // the length.
    575                 outWord[wordPos] = character;
    576                 if (FLAG_HAS_MULTIPLE_CHARS & flags) {
    577                     int nextChar = getCodePointAndForwardPointer(root, &pos);
    578                     // We count chars in order to avoid infinite loops if the file is broken or
    579                     // if there is some other bug
    580                     int charCount = maxDepth;
    581                     while (NOT_A_CODE_POINT != nextChar && --charCount > 0) {
    582                         outWord[++wordPos] = nextChar;
    583                         nextChar = getCodePointAndForwardPointer(root, &pos);
    584                     }
    585                 }
    586                 *outUnigramProbability = readProbabilityWithoutMovingPointer(root, pos);
    587                 return ++wordPos;
    588             }
    589             // We need to skip past this char group, so skip any remaining chars after the
    590             // first and possibly the probability.
    591             if (FLAG_HAS_MULTIPLE_CHARS & flags) {
    592                 pos = skipOtherCharacters(root, pos);
    593             }
    594             pos = skipProbability(flags, pos);
    595 
    596             // The fact that this group has children is very important. Since we already know
    597             // that this group does not match, if it has no children we know it is irrelevant
    598             // to what we are searching for.
    599             const bool hasChildren = (FLAG_GROUP_ADDRESS_TYPE_NOADDRESS !=
    600                     (MASK_GROUP_ADDRESS_TYPE & flags));
    601             // We will write in `found' whether we have passed the children address we are
    602             // searching for. For example if we search for "beer", the children of b are less
    603             // than the address we are searching for and the children of c are greater. When we
    604             // come here for c, we realize this is too big, and that we should descend b.
    605             bool found;
    606             if (hasChildren) {
    607                 // Here comes the tricky part. First, read the children position.
    608                 const int childrenPos = readChildrenPosition(root, flags, pos);
    609                 if (childrenPos > address) {
    610                     // If the children pos is greater than address, it means the previous chargroup,
    611                     // which address is stored in lastCandidateGroupPos, was the right one.
    612                     found = true;
    613                 } else if (1 >= charGroupCount) {
    614                     // However if we are on the LAST group of this node, and we have NOT shot the
    615                     // address we should descend THIS node. So we trick the lastCandidateGroupPos
    616                     // so that we will descend this node, not the previous one.
    617                     lastCandidateGroupPos = startPos;
    618                     found = true;
    619                 } else {
    620                     // Else, we should continue looking.
    621                     found = false;
    622                 }
    623             } else {
    624                 // Even if we don't have children here, we could still be on the last group of this
    625                 // node. If this is the case, we should descend the last group that had children,
    626                 // and their address is already in lastCandidateGroup.
    627                 found = (1 >= charGroupCount);
    628             }
    629 
    630             if (found) {
    631                 // Okay, we found the group we should descend. Its address is in
    632                 // the lastCandidateGroupPos variable, so we just re-read it.
    633                 if (0 != lastCandidateGroupPos) {
    634                     const uint8_t lastFlags =
    635                             getFlagsAndForwardPointer(root, &lastCandidateGroupPos);
    636                     const int lastChar =
    637                             getCodePointAndForwardPointer(root, &lastCandidateGroupPos);
    638                     // We copy all the characters in this group to the buffer
    639                     outWord[wordPos] = lastChar;
    640                     if (FLAG_HAS_MULTIPLE_CHARS & lastFlags) {
    641                         int nextChar = getCodePointAndForwardPointer(root, &lastCandidateGroupPos);
    642                         int charCount = maxDepth;
    643                         while (-1 != nextChar && --charCount > 0) {
    644                             outWord[++wordPos] = nextChar;
    645                             nextChar = getCodePointAndForwardPointer(root, &lastCandidateGroupPos);
    646                         }
    647                     }
    648                     ++wordPos;
    649                     // Now we only need to branch to the children address. Skip the probability if
    650                     // it's there, read pos, and break to resume the search at pos.
    651                     lastCandidateGroupPos = skipProbability(lastFlags, lastCandidateGroupPos);
    652                     pos = readChildrenPosition(root, lastFlags, lastCandidateGroupPos);
    653                     break;
    654                 } else {
    655                     // Here is a little tricky part: we come here if we found out that all children
    656                     // addresses in this group are bigger than the address we are searching for.
    657                     // Should we conclude the word is not in the dictionary? No! It could still be
    658                     // one of the remaining chargroups in this node, so we have to keep looking in
    659                     // this node until we find it (or we realize it's not there either, in which
    660                     // case it's actually not in the dictionary). Pass the end of this group, ready
    661                     // to start the next one.
    662                     pos = skipChildrenPosAndAttributes(root, flags, pos);
    663                 }
    664             } else {
    665                 // If we did not find it, we should record the last children address for the next
    666                 // iteration.
    667                 if (hasChildren) lastCandidateGroupPos = startPos;
    668                 // Now skip the end of this group (children pos and the attributes if any) so that
    669                 // our pos is after the end of this char group, at the start of the next one.
    670                 pos = skipChildrenPosAndAttributes(root, flags, pos);
    671             }
    672 
    673         }
    674     }
    675     // If we have looked through all the chargroups and found no match, the address is
    676     // not the address of a terminal in this dictionary.
    677     return 0;
    678 }
    679 
    680 static inline int backoff(const int unigramProbability) {
    681     return unigramProbability;
    682     // For some reason, applying the backoff weight gives bad results in tests. To apply the
    683     // backoff weight, we divide the probability by 2, which in our storing format means
    684     // decreasing the score by 8.
    685     // TODO: figure out what's wrong with this.
    686     // return unigramProbability > 8 ? unigramProbability - 8 : (0 == unigramProbability ? 0 : 8);
    687 }
    688 
    689 inline int BinaryFormat::computeProbabilityForBigram(
    690         const int unigramProbability, const int bigramProbability) {
    691     // We divide the range [unigramProbability..255] in 16.5 steps - in other words, we want the
    692     // unigram probability to be the median value of the 17th step from the top. A value of
    693     // 0 for the bigram probability represents the middle of the 16th step from the top,
    694     // while a value of 15 represents the middle of the top step.
    695     // See makedict.BinaryDictInputOutput for details.
    696     const float stepSize = static_cast<float>(MAX_PROBABILITY - unigramProbability)
    697             / (1.5f + MAX_BIGRAM_ENCODED_PROBABILITY);
    698     return unigramProbability
    699             + static_cast<int>(static_cast<float>(bigramProbability + 1) * stepSize);
    700 }
    701 
    702 // This returns a probability in log space.
    703 inline int BinaryFormat::getProbability(const int position, const std::map<int, int> *bigramMap,
    704         const uint8_t *bigramFilter, const int unigramProbability) {
    705     if (!bigramMap || !bigramFilter) return backoff(unigramProbability);
    706     if (!isInFilter(bigramFilter, position)) return backoff(unigramProbability);
    707     const std::map<int, int>::const_iterator bigramProbabilityIt = bigramMap->find(position);
    708     if (bigramProbabilityIt != bigramMap->end()) {
    709         const int bigramProbability = bigramProbabilityIt->second;
    710         return computeProbabilityForBigram(unigramProbability, bigramProbability);
    711     }
    712     return backoff(unigramProbability);
    713 }
    714 
    715 // This returns a probability in log space.
    716 inline int BinaryFormat::getBigramProbabilityFromHashMap(const int position,
    717         const hash_map_compat<int, int> *bigramMap, const int unigramProbability) {
    718     if (!bigramMap) return backoff(unigramProbability);
    719     const hash_map_compat<int, int>::const_iterator bigramProbabilityIt = bigramMap->find(position);
    720     if (bigramProbabilityIt != bigramMap->end()) {
    721         const int bigramProbability = bigramProbabilityIt->second;
    722         return computeProbabilityForBigram(unigramProbability, bigramProbability);
    723     }
    724     return backoff(unigramProbability);
    725 }
    726 
    727 AK_FORCE_INLINE void BinaryFormat::fillBigramProbabilityToHashMap(
    728         const uint8_t *const root, int position, hash_map_compat<int, int> *bigramMap) {
    729     position = getBigramListPositionForWordPosition(root, position);
    730     if (0 == position) return;
    731 
    732     uint8_t bigramFlags;
    733     do {
    734         bigramFlags = getFlagsAndForwardPointer(root, &position);
    735         const int probability = MASK_ATTRIBUTE_PROBABILITY & bigramFlags;
    736         const int bigramPos = getAttributeAddressAndForwardPointer(root, bigramFlags,
    737                 &position);
    738         (*bigramMap)[bigramPos] = probability;
    739     } while (FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags);
    740 }
    741 
    742 AK_FORCE_INLINE int BinaryFormat::getBigramProbability(const uint8_t *const root, int position,
    743         const int nextPosition, const int unigramProbability) {
    744     position = getBigramListPositionForWordPosition(root, position);
    745     if (0 == position) return backoff(unigramProbability);
    746 
    747     uint8_t bigramFlags;
    748     do {
    749         bigramFlags = getFlagsAndForwardPointer(root, &position);
    750         const int bigramPos = getAttributeAddressAndForwardPointer(
    751                 root, bigramFlags, &position);
    752         if (bigramPos == nextPosition) {
    753             const int bigramProbability = MASK_ATTRIBUTE_PROBABILITY & bigramFlags;
    754             return computeProbabilityForBigram(unigramProbability, bigramProbability);
    755         }
    756     } while (FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags);
    757     return backoff(unigramProbability);
    758 }
    759 
    760 // Returns a pointer to the start of the bigram list.
    761 AK_FORCE_INLINE int BinaryFormat::getBigramListPositionForWordPosition(
    762         const uint8_t *const root, int position) {
    763     if (NOT_VALID_WORD == position) return 0;
    764     const uint8_t flags = getFlagsAndForwardPointer(root, &position);
    765     if (!(flags & FLAG_HAS_BIGRAMS)) return 0;
    766     if (flags & FLAG_HAS_MULTIPLE_CHARS) {
    767         position = skipOtherCharacters(root, position);
    768     } else {
    769         getCodePointAndForwardPointer(root, &position);
    770     }
    771     position = skipProbability(flags, position);
    772     position = skipChildrenPosition(flags, position);
    773     position = skipShortcuts(root, flags, position);
    774     return position;
    775 }
    776 
    777 } // namespace latinime
    778 #endif // LATINIME_BINARY_FORMAT_H
    779