Home | History | Annotate | Download | only in share
      1 /*
      2  * Copyright (C) 2009 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include <assert.h>
     18 #include <stdlib.h>
     19 #include <string.h>
     20 #include "../include/dictlist.h"
     21 #include "../include/mystdlib.h"
     22 #include "../include/ngram.h"
     23 #include "../include/searchutility.h"
     24 
     25 namespace ime_pinyin {
     26 
     27 DictList::DictList() {
     28   initialized_ = false;
     29   scis_num_ = 0;
     30   scis_hz_ = NULL;
     31   scis_splid_ = NULL;
     32   buf_ = NULL;
     33   spl_trie_ = SpellingTrie::get_cpinstance();
     34 
     35   assert(kMaxLemmaSize == 8);
     36   cmp_func_[0] = cmp_hanzis_1;
     37   cmp_func_[1] = cmp_hanzis_2;
     38   cmp_func_[2] = cmp_hanzis_3;
     39   cmp_func_[3] = cmp_hanzis_4;
     40   cmp_func_[4] = cmp_hanzis_5;
     41   cmp_func_[5] = cmp_hanzis_6;
     42   cmp_func_[6] = cmp_hanzis_7;
     43   cmp_func_[7] = cmp_hanzis_8;
     44 }
     45 
     46 DictList::~DictList() {
     47   free_resource();
     48 }
     49 
     50 bool DictList::alloc_resource(size_t buf_size, size_t scis_num) {
     51   // Allocate memory
     52   buf_ = static_cast<char16*>(malloc(buf_size * sizeof(char16)));
     53   if (NULL == buf_)
     54     return false;
     55 
     56   scis_num_ = scis_num;
     57 
     58   scis_hz_ = static_cast<char16*>(malloc(scis_num_ * sizeof(char16)));
     59   if (NULL == scis_hz_)
     60     return false;
     61 
     62   scis_splid_ = static_cast<SpellingId*>
     63       (malloc(scis_num_ * sizeof(SpellingId)));
     64 
     65   if (NULL == scis_splid_)
     66     return false;
     67 
     68   return true;
     69 }
     70 
     71 void DictList::free_resource() {
     72   if (NULL != buf_)
     73     free(buf_);
     74   buf_ = NULL;
     75 
     76   if (NULL != scis_hz_)
     77     free(scis_hz_);
     78   scis_hz_ = NULL;
     79 
     80   if (NULL != scis_splid_)
     81     free(scis_splid_);
     82   scis_splid_ = NULL;
     83 }
     84 
     85 #ifdef ___BUILD_MODEL___
     86 bool DictList::init_list(const SingleCharItem *scis, size_t scis_num,
     87                          const LemmaEntry *lemma_arr, size_t lemma_num) {
     88   if (NULL == scis || 0 == scis_num || NULL == lemma_arr || 0 == lemma_num)
     89     return false;
     90 
     91   initialized_ = false;
     92 
     93   if (NULL != buf_)
     94     free(buf_);
     95 
     96   // calculate the size
     97   size_t buf_size = calculate_size(lemma_arr, lemma_num);
     98   if (0 == buf_size)
     99     return false;
    100 
    101   if (!alloc_resource(buf_size, scis_num))
    102     return false;
    103 
    104   fill_scis(scis, scis_num);
    105 
    106   // Copy the related content from the array to inner buffer
    107   fill_list(lemma_arr, lemma_num);
    108 
    109   initialized_ = true;
    110   return true;
    111 }
    112 
    113 size_t DictList::calculate_size(const LemmaEntry* lemma_arr, size_t lemma_num) {
    114   size_t last_hz_len = 0;
    115   size_t list_size = 0;
    116   size_t id_num = 0;
    117 
    118   for (size_t i = 0; i < lemma_num; i++) {
    119     if (0 == i) {
    120       last_hz_len = lemma_arr[i].hz_str_len;
    121 
    122       assert(last_hz_len > 0);
    123       assert(lemma_arr[0].idx_by_hz == 1);
    124 
    125       id_num++;
    126       start_pos_[0] = 0;
    127       start_id_[0] = id_num;
    128 
    129       last_hz_len = 1;
    130       list_size += last_hz_len;
    131     } else {
    132       size_t current_hz_len = lemma_arr[i].hz_str_len;
    133 
    134       assert(current_hz_len >= last_hz_len);
    135 
    136       if (current_hz_len == last_hz_len) {
    137           list_size += current_hz_len;
    138           id_num++;
    139       } else {
    140         for (size_t len = last_hz_len; len < current_hz_len - 1; len++) {
    141           start_pos_[len] = start_pos_[len - 1];
    142           start_id_[len] = start_id_[len - 1];
    143         }
    144 
    145         start_pos_[current_hz_len - 1] = list_size;
    146 
    147         id_num++;
    148         start_id_[current_hz_len - 1] = id_num;
    149 
    150         last_hz_len = current_hz_len;
    151         list_size += current_hz_len;
    152       }
    153     }
    154   }
    155 
    156   for (size_t i = last_hz_len; i <= kMaxLemmaSize; i++) {
    157     if (0 == i) {
    158       start_pos_[0] = 0;
    159       start_id_[0] = 1;
    160     } else {
    161       start_pos_[i] = list_size;
    162       start_id_[i] = id_num;
    163     }
    164   }
    165 
    166   return start_pos_[kMaxLemmaSize];
    167 }
    168 
    169 void DictList::fill_scis(const SingleCharItem *scis, size_t scis_num) {
    170   assert(scis_num_ == scis_num);
    171 
    172   for (size_t pos = 0; pos < scis_num_; pos++) {
    173     scis_hz_[pos] = scis[pos].hz;
    174     scis_splid_[pos] = scis[pos].splid;
    175   }
    176 }
    177 
    178 void DictList::fill_list(const LemmaEntry* lemma_arr, size_t lemma_num) {
    179   size_t current_pos = 0;
    180 
    181   utf16_strncpy(buf_, lemma_arr[0].hanzi_str,
    182                 lemma_arr[0].hz_str_len);
    183 
    184   current_pos = lemma_arr[0].hz_str_len;
    185 
    186   size_t id_num = 1;
    187 
    188   for (size_t i = 1; i < lemma_num; i++) {
    189     utf16_strncpy(buf_ + current_pos, lemma_arr[i].hanzi_str,
    190                   lemma_arr[i].hz_str_len);
    191 
    192     id_num++;
    193     current_pos += lemma_arr[i].hz_str_len;
    194   }
    195 
    196   assert(current_pos == start_pos_[kMaxLemmaSize]);
    197   assert(id_num == start_id_[kMaxLemmaSize]);
    198 }
    199 
    200 char16* DictList::find_pos2_startedbyhz(char16 hz_char) {
    201   char16 *found_2w = static_cast<char16*>
    202                      (mybsearch(&hz_char, buf_ + start_pos_[1],
    203                                 (start_pos_[2] - start_pos_[1]) / 2,
    204                                 sizeof(char16) * 2, cmp_hanzis_1));
    205   if (NULL == found_2w)
    206     return NULL;
    207 
    208   while (found_2w > buf_ + start_pos_[1] && *found_2w == *(found_2w - 1))
    209     found_2w -= 2;
    210 
    211   return found_2w;
    212 }
    213 #endif  // ___BUILD_MODEL___
    214 
    215 char16* DictList::find_pos_startedbyhzs(const char16 last_hzs[],
    216     size_t word_len, int (*cmp_func)(const void *, const void *)) {
    217   char16 *found_w = static_cast<char16*>
    218                     (mybsearch(last_hzs, buf_ + start_pos_[word_len - 1],
    219                                (start_pos_[word_len] - start_pos_[word_len - 1])
    220                                / word_len,
    221                                sizeof(char16) * word_len, cmp_func));
    222 
    223   if (NULL == found_w)
    224     return NULL;
    225 
    226   while (found_w > buf_ + start_pos_[word_len -1] &&
    227          cmp_func(found_w, found_w - word_len) == 0)
    228     found_w -= word_len;
    229 
    230   return found_w;
    231 }
    232 
    233 size_t DictList::predict(const char16 last_hzs[], uint16 hzs_len,
    234                          NPredictItem *npre_items, size_t npre_max,
    235                          size_t b4_used) {
    236   assert(hzs_len <= kMaxPredictSize && hzs_len > 0);
    237 
    238   // 1. Prepare work
    239   int (*cmp_func)(const void *, const void *) = cmp_func_[hzs_len - 1];
    240 
    241   NGram& ngram = NGram::get_instance();
    242 
    243   size_t item_num = 0;
    244 
    245   // 2. Do prediction
    246   for (uint16 pre_len = 1; pre_len <= kMaxPredictSize + 1 - hzs_len;
    247        pre_len++) {
    248     uint16 word_len = hzs_len + pre_len;
    249     char16 *w_buf = find_pos_startedbyhzs(last_hzs, word_len, cmp_func);
    250     if (NULL == w_buf)
    251       continue;
    252     while (w_buf < buf_ + start_pos_[word_len] &&
    253            cmp_func(w_buf, last_hzs) == 0 &&
    254            item_num < npre_max) {
    255       memset(npre_items + item_num, 0, sizeof(NPredictItem));
    256       utf16_strncpy(npre_items[item_num].pre_hzs, w_buf + hzs_len, pre_len);
    257       npre_items[item_num].psb =
    258         ngram.get_uni_psb((size_t)(w_buf - buf_ - start_pos_[word_len - 1])
    259                           / word_len + start_id_[word_len - 1]);
    260       npre_items[item_num].his_len = hzs_len;
    261       item_num++;
    262       w_buf += word_len;
    263     }
    264   }
    265 
    266   size_t new_num = 0;
    267   for (size_t i = 0; i < item_num; i++) {
    268     // Try to find it in the existing items
    269     size_t e_pos;
    270     for (e_pos = 1; e_pos <= b4_used; e_pos++) {
    271       if (utf16_strncmp((*(npre_items - e_pos)).pre_hzs, npre_items[i].pre_hzs,
    272                         kMaxPredictSize) == 0)
    273         break;
    274     }
    275     if (e_pos <= b4_used)
    276       continue;
    277 
    278     // If not found, append it to the buffer
    279     npre_items[new_num] = npre_items[i];
    280     new_num++;
    281   }
    282 
    283   return new_num;
    284 }
    285 
    286 uint16 DictList::get_lemma_str(LemmaIdType id_lemma, char16 *str_buf,
    287                                uint16 str_max) {
    288   if (!initialized_ || id_lemma >= start_id_[kMaxLemmaSize] || NULL == str_buf
    289       || str_max <= 1)
    290     return 0;
    291 
    292   // Find the range
    293   for (uint16 i = 0; i < kMaxLemmaSize; i++) {
    294     if (i + 1 > str_max - 1)
    295       return 0;
    296     if (start_id_[i] <= id_lemma && start_id_[i + 1] > id_lemma) {
    297       size_t id_span = id_lemma - start_id_[i];
    298 
    299       uint16 *buf = buf_ + start_pos_[i] + id_span * (i + 1);
    300       for (uint16 len = 0; len <= i; len++) {
    301         str_buf[len] = buf[len];
    302       }
    303       str_buf[i+1] = (char16)'\0';
    304       return i + 1;
    305     }
    306   }
    307   return 0;
    308 }
    309 
    310 uint16 DictList::get_splids_for_hanzi(char16 hanzi, uint16 half_splid,
    311                                       uint16 *splids, uint16 max_splids) {
    312   char16 *hz_found = static_cast<char16*>
    313       (mybsearch(&hanzi, scis_hz_, scis_num_, sizeof(char16), cmp_hanzis_1));
    314   assert(NULL != hz_found && hanzi == *hz_found);
    315 
    316   // Move to the first one.
    317   while (hz_found > scis_hz_ && hanzi == *(hz_found - 1))
    318     hz_found--;
    319 
    320   // First try to found if strict comparison result is not zero.
    321   char16 *hz_f = hz_found;
    322   bool strict = false;
    323   while (hz_f < scis_hz_ + scis_num_ && hanzi == *hz_f) {
    324     uint16 pos = hz_f - scis_hz_;
    325     if (0 == half_splid || scis_splid_[pos].half_splid == half_splid) {
    326       strict = true;
    327     }
    328     hz_f++;
    329   }
    330 
    331   uint16 found_num = 0;
    332   while (hz_found < scis_hz_ + scis_num_ && hanzi == *hz_found) {
    333     uint16 pos = hz_found - scis_hz_;
    334     if (0 == half_splid ||
    335         (strict && scis_splid_[pos].half_splid == half_splid) ||
    336         (!strict && spl_trie_->half_full_compatible(half_splid,
    337         scis_splid_[pos].full_splid))) {
    338       assert(found_num + 1 < max_splids);
    339       splids[found_num] = scis_splid_[pos].full_splid;
    340       found_num++;
    341     }
    342     hz_found++;
    343   }
    344 
    345   return found_num;
    346 }
    347 
    348 LemmaIdType DictList::get_lemma_id(const char16 *str, uint16 str_len) {
    349   if (NULL == str || str_len > kMaxLemmaSize)
    350     return 0;
    351 
    352   char16 *found = find_pos_startedbyhzs(str, str_len, cmp_func_[str_len - 1]);
    353   if (NULL == found)
    354     return 0;
    355 
    356   assert(found > buf_);
    357   assert(static_cast<size_t>(found - buf_) >= start_pos_[str_len - 1]);
    358   return static_cast<LemmaIdType>
    359       (start_id_[str_len - 1] +
    360        (found - buf_ - start_pos_[str_len - 1]) / str_len);
    361 }
    362 
    363 void DictList::convert_to_hanzis(char16 *str, uint16 str_len) {
    364   assert(NULL != str);
    365 
    366   for (uint16 str_pos = 0; str_pos < str_len; str_pos++) {
    367     str[str_pos] = scis_hz_[str[str_pos]];
    368   }
    369 }
    370 
    371 void DictList::convert_to_scis_ids(char16 *str, uint16 str_len) {
    372   assert(NULL != str);
    373 
    374   for (uint16 str_pos = 0; str_pos < str_len; str_pos++) {
    375     str[str_pos] = 0x100;
    376   }
    377 }
    378 
    379 bool DictList::save_list(FILE *fp) {
    380   if (!initialized_ || NULL == fp)
    381     return false;
    382 
    383   if (NULL == buf_ || 0 == start_pos_[kMaxLemmaSize] ||
    384       NULL == scis_hz_ || NULL == scis_splid_ || 0 == scis_num_)
    385     return false;
    386 
    387   if (fwrite(&scis_num_, sizeof(size_t), 1, fp) != 1)
    388     return false;
    389 
    390   if (fwrite(start_pos_, sizeof(size_t), kMaxLemmaSize + 1, fp) !=
    391       kMaxLemmaSize + 1)
    392     return false;
    393 
    394   if (fwrite(start_id_, sizeof(size_t), kMaxLemmaSize + 1, fp) !=
    395       kMaxLemmaSize + 1)
    396     return false;
    397 
    398   if (fwrite(scis_hz_, sizeof(char16), scis_num_, fp) != scis_num_)
    399     return false;
    400 
    401   if (fwrite(scis_splid_, sizeof(SpellingId), scis_num_, fp) != scis_num_)
    402     return false;
    403 
    404   if (fwrite(buf_, sizeof(char16), start_pos_[kMaxLemmaSize], fp) !=
    405       start_pos_[kMaxLemmaSize])
    406     return false;
    407 
    408   return true;
    409 }
    410 
    411 bool DictList::load_list(FILE *fp) {
    412   if (NULL == fp)
    413     return false;
    414 
    415   initialized_ = false;
    416 
    417   if (fread(&scis_num_, sizeof(size_t), 1, fp) != 1)
    418     return false;
    419 
    420   if (fread(start_pos_, sizeof(size_t), kMaxLemmaSize + 1, fp) !=
    421       kMaxLemmaSize + 1)
    422     return false;
    423 
    424   if (fread(start_id_, sizeof(size_t), kMaxLemmaSize + 1, fp) !=
    425       kMaxLemmaSize + 1)
    426     return false;
    427 
    428   free_resource();
    429 
    430   if (!alloc_resource(start_pos_[kMaxLemmaSize], scis_num_))
    431     return false;
    432 
    433   if (fread(scis_hz_, sizeof(char16), scis_num_, fp) != scis_num_)
    434     return false;
    435 
    436   if (fread(scis_splid_, sizeof(SpellingId), scis_num_, fp) != scis_num_)
    437     return false;
    438 
    439   if (fread(buf_, sizeof(char16), start_pos_[kMaxLemmaSize], fp) !=
    440       start_pos_[kMaxLemmaSize])
    441     return false;
    442 
    443   initialized_ = true;
    444   return true;
    445 }
    446 }  // namespace ime_pinyin
    447