Home | History | Annotate | Download | only in marisa_alpha
      1 #include <algorithm>
      2 #include <functional>
      3 #include <queue>
      4 #include <stdexcept>
      5 
      6 #include "range.h"
      7 #include "trie.h"
      8 
      9 namespace marisa_alpha {
     10 
     11 void Trie::build(const char * const *keys, std::size_t num_keys,
     12     const std::size_t *key_lengths, const double *key_weights,
     13     UInt32 *key_ids, int flags) {
     14   MARISA_ALPHA_THROW_IF((keys == NULL) && (num_keys != 0),
     15       MARISA_ALPHA_PARAM_ERROR);
     16   Vector<Key<String> > temp_keys;
     17   temp_keys.resize(num_keys);
     18   for (std::size_t i = 0; i < temp_keys.size(); ++i) {
     19     MARISA_ALPHA_THROW_IF(keys[i] == NULL, MARISA_ALPHA_PARAM_ERROR);
     20     std::size_t length = 0;
     21     if (key_lengths == NULL) {
     22       while (keys[i][length] != '\0') {
     23         ++length;
     24       }
     25     } else {
     26       length = key_lengths[i];
     27     }
     28     MARISA_ALPHA_THROW_IF(length > MARISA_ALPHA_MAX_LENGTH,
     29         MARISA_ALPHA_SIZE_ERROR);
     30     temp_keys[i].set_str(String(keys[i], length));
     31     temp_keys[i].set_weight((key_weights != NULL) ? key_weights[i] : 1.0);
     32   }
     33   build_trie(temp_keys, key_ids, flags);
     34 }
     35 
     36 void Trie::build(const std::vector<std::string> &keys,
     37     std::vector<UInt32> *key_ids, int flags) {
     38   Vector<Key<String> > temp_keys;
     39   temp_keys.resize(keys.size());
     40   for (std::size_t i = 0; i < temp_keys.size(); ++i) {
     41     MARISA_ALPHA_THROW_IF(keys[i].length() > MARISA_ALPHA_MAX_LENGTH,
     42         MARISA_ALPHA_SIZE_ERROR);
     43     temp_keys[i].set_str(String(keys[i].c_str(), keys[i].length()));
     44     temp_keys[i].set_weight(1.0);
     45   }
     46   build_trie(temp_keys, key_ids, flags);
     47 }
     48 
     49 void Trie::build(const std::vector<std::pair<std::string, double> > &keys,
     50     std::vector<UInt32> *key_ids, int flags) {
     51   Vector<Key<String> > temp_keys;
     52   temp_keys.resize(keys.size());
     53   for (std::size_t i = 0; i < temp_keys.size(); ++i) {
     54     MARISA_ALPHA_THROW_IF(keys[i].first.length() > MARISA_ALPHA_MAX_LENGTH,
     55         MARISA_ALPHA_SIZE_ERROR);
     56     temp_keys[i].set_str(String(
     57         keys[i].first.c_str(), keys[i].first.length()));
     58     temp_keys[i].set_weight(keys[i].second);
     59   }
     60   build_trie(temp_keys, key_ids, flags);
     61 }
     62 
     63 void Trie::build_trie(Vector<Key<String> > &keys,
     64     std::vector<UInt32> *key_ids, int flags) {
     65   if (key_ids == NULL) {
     66     build_trie(keys, static_cast<UInt32 *>(NULL), flags);
     67     return;
     68   }
     69   try {
     70     std::vector<UInt32> temp_key_ids(keys.size());
     71     build_trie(keys, temp_key_ids.empty() ? NULL : &temp_key_ids[0], flags);
     72     key_ids->swap(temp_key_ids);
     73   } catch (const std::bad_alloc &) {
     74     MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
     75   } catch (const std::length_error &) {
     76     MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
     77   }
     78 }
     79 
     80 void Trie::build_trie(Vector<Key<String> > &keys,
     81     UInt32 *key_ids, int flags) {
     82   Trie temp;
     83   Vector<UInt32> terminals;
     84   Progress progress(flags);
     85   MARISA_ALPHA_THROW_IF(!progress.is_valid(), MARISA_ALPHA_PARAM_ERROR);
     86   temp.build_trie(keys, &terminals, progress);
     87 
     88   typedef std::pair<UInt32, UInt32> TerminalIdPair;
     89   Vector<TerminalIdPair> pairs;
     90   pairs.resize(terminals.size());
     91   for (UInt32 i = 0; i < pairs.size(); ++i) {
     92     pairs[i].first = terminals[i];
     93     pairs[i].second = i;
     94   }
     95   terminals.clear();
     96   std::sort(pairs.begin(), pairs.end());
     97 
     98   UInt32 node = 0;
     99   for (UInt32 i = 0; i < pairs.size(); ++i) {
    100     while (node < pairs[i].first) {
    101       temp.terminal_flags_.push_back(false);
    102       ++node;
    103     }
    104     if (node == pairs[i].first) {
    105       temp.terminal_flags_.push_back(true);
    106       ++node;
    107     }
    108   }
    109   while (node < temp.labels_.size()) {
    110     temp.terminal_flags_.push_back(false);
    111     ++node;
    112   }
    113   terminal_flags_.push_back(false);
    114   temp.terminal_flags_.build();
    115   temp.terminal_flags_.clear_select0s();
    116   progress.test_total_size(temp.terminal_flags_.total_size());
    117 
    118   if (key_ids != NULL) {
    119     for (UInt32 i = 0; i < pairs.size(); ++i) {
    120       key_ids[pairs[i].second] = temp.node_to_key_id(pairs[i].first);
    121     }
    122   }
    123   MARISA_ALPHA_THROW_IF(progress.total_size() != temp.total_size(),
    124       MARISA_ALPHA_UNEXPECTED_ERROR);
    125   temp.swap(this);
    126 }
    127 
    128 template <typename T>
    129 void Trie::build_trie(Vector<Key<T> > &keys,
    130     Vector<UInt32> *terminals, Progress &progress) {
    131   build_cur(keys, terminals, progress);
    132   progress.test_total_size(louds_.total_size());
    133   progress.test_total_size(sizeof(num_first_branches_));
    134   progress.test_total_size(sizeof(num_keys_));
    135   if (link_flags_.empty()) {
    136     labels_.shrink();
    137     progress.test_total_size(labels_.total_size());
    138     progress.test_total_size(link_flags_.total_size());
    139     progress.test_total_size(links_.total_size());
    140     progress.test_total_size(tail_.total_size());
    141     return;
    142   }
    143 
    144   Vector<UInt32> next_terminals;
    145   build_next(keys, &next_terminals, progress);
    146 
    147   if (has_trie()) {
    148     progress.test_total_size(trie_->terminal_flags_.total_size());
    149   } else if (tail_.mode() == MARISA_ALPHA_BINARY_TAIL) {
    150     labels_.push_back('\0');
    151     link_flags_.push_back(true);
    152   }
    153   link_flags_.build();
    154 
    155   for (UInt32 i = 0; i < next_terminals.size(); ++i) {
    156     labels_[link_flags_.select1(i)] = (UInt8)(next_terminals[i] % 256);
    157     next_terminals[i] /= 256;
    158   }
    159   link_flags_.clear_select0s();
    160   if (has_trie() || (tail_.mode() == MARISA_ALPHA_TEXT_TAIL)) {
    161     link_flags_.clear_select1s();
    162   }
    163 
    164   links_.build(next_terminals);
    165   labels_.shrink();
    166   progress.test_total_size(labels_.total_size());
    167   progress.test_total_size(link_flags_.total_size());
    168   progress.test_total_size(links_.total_size());
    169   progress.test_total_size(tail_.total_size());
    170 }
    171 
    172 template <typename T>
    173 void Trie::build_cur(Vector<Key<T> > &keys,
    174     Vector<UInt32> *terminals, Progress &progress) try {
    175   num_keys_ = sort_keys(keys);
    176   louds_.push_back(true);
    177   louds_.push_back(false);
    178   labels_.push_back('\0');
    179   link_flags_.push_back(false);
    180 
    181   Vector<Key<T> > rest_keys;
    182   std::queue<Range> queue;
    183   Vector<WRange> wranges;
    184   queue.push(Range(0, (UInt32)keys.size(), 0));
    185   while (!queue.empty()) {
    186     const UInt32 node = (UInt32)(link_flags_.size() - queue.size());
    187     Range range = queue.front();
    188     queue.pop();
    189 
    190     while ((range.begin() < range.end()) &&
    191         (keys[range.begin()].str().length() == range.pos())) {
    192       keys[range.begin()].set_terminal(node);
    193       range.set_begin(range.begin() + 1);
    194     }
    195     if (range.begin() == range.end()) {
    196       louds_.push_back(false);
    197       continue;
    198     }
    199 
    200     wranges.clear();
    201     double weight = keys[range.begin()].weight();
    202     for (UInt32 i = range.begin() + 1; i < range.end(); ++i) {
    203       if (keys[i - 1].str()[range.pos()] != keys[i].str()[range.pos()]) {
    204         wranges.push_back(WRange(range.begin(), i, range.pos(), weight));
    205         range.set_begin(i);
    206         weight = 0.0;
    207       }
    208       weight += keys[i].weight();
    209     }
    210     wranges.push_back(WRange(range, weight));
    211     if (progress.order() == MARISA_ALPHA_WEIGHT_ORDER) {
    212       std::stable_sort(wranges.begin(), wranges.end(), std::greater<WRange>());
    213     }
    214     if (node == 0) {
    215       num_first_branches_ = wranges.size();
    216     }
    217     for (UInt32 i = 0; i < wranges.size(); ++i) {
    218       const WRange &wrange = wranges[i];
    219       UInt32 pos = wrange.pos() + 1;
    220       if ((progress.tail() != MARISA_ALPHA_WITHOUT_TAIL) ||
    221           !progress.is_last()) {
    222         while (pos < keys[wrange.begin()].str().length()) {
    223           UInt32 j;
    224           for (j = wrange.begin() + 1; j < wrange.end(); ++j) {
    225             if (keys[j - 1].str()[pos] != keys[j].str()[pos]) {
    226               break;
    227             }
    228           }
    229           if (j < wrange.end()) {
    230             break;
    231           }
    232           ++pos;
    233         }
    234       }
    235       if ((progress.trie() != MARISA_ALPHA_PATRICIA_TRIE) &&
    236           (pos != keys[wrange.end() - 1].str().length())) {
    237         pos = wrange.pos() + 1;
    238       }
    239       louds_.push_back(true);
    240       if (pos == wrange.pos() + 1) {
    241         labels_.push_back(keys[wrange.begin()].str()[wrange.pos()]);
    242         link_flags_.push_back(false);
    243       } else {
    244         labels_.push_back('\0');
    245         link_flags_.push_back(true);
    246         Key<T> rest_key;
    247         rest_key.set_str(keys[wrange.begin()].str().substr(
    248             wrange.pos(), pos - wrange.pos()));
    249         rest_key.set_weight(wrange.weight());
    250         rest_keys.push_back(rest_key);
    251       }
    252       wranges[i].set_pos(pos);
    253       queue.push(wranges[i].range());
    254     }
    255     louds_.push_back(false);
    256   }
    257   louds_.push_back(false);
    258   louds_.build();
    259   if (progress.trie_id() != 0) {
    260     louds_.clear_select0s();
    261   }
    262   if (rest_keys.empty()) {
    263     link_flags_.clear();
    264   }
    265 
    266   build_terminals(keys, terminals);
    267   keys.swap(&rest_keys);
    268 } catch (const std::bad_alloc &) {
    269   MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
    270 } catch (const std::length_error &) {
    271   MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
    272 }
    273 
    274 void Trie::build_next(Vector<Key<String> > &keys,
    275     Vector<UInt32> *terminals, Progress &progress) {
    276   if (progress.is_last()) {
    277     Vector<String> strs;
    278     strs.resize(keys.size());
    279     for (UInt32 i = 0; i < strs.size(); ++i) {
    280       strs[i] = keys[i].str();
    281     }
    282     tail_.build(strs, terminals, progress.tail());
    283     return;
    284   }
    285   Vector<Key<RString> > rkeys;
    286   rkeys.resize(keys.size());
    287   for (UInt32 i = 0; i < rkeys.size(); ++i) {
    288     rkeys[i].set_str(RString(keys[i].str()));
    289     rkeys[i].set_weight(keys[i].weight());
    290   }
    291   keys.clear();
    292   trie_.reset(new (std::nothrow) Trie);
    293   MARISA_ALPHA_THROW_IF(!has_trie(), MARISA_ALPHA_MEMORY_ERROR);
    294   trie_->build_trie(rkeys, terminals, ++progress);
    295 }
    296 
    297 void Trie::build_next(Vector<Key<RString> > &rkeys,
    298     Vector<UInt32> *terminals, Progress &progress) {
    299   if (progress.is_last()) {
    300     Vector<String> strs;
    301     strs.resize(rkeys.size());
    302     for (UInt32 i = 0; i < strs.size(); ++i) {
    303       strs[i] = String(rkeys[i].str().ptr(), rkeys[i].str().length());
    304     }
    305     tail_.build(strs, terminals, progress.tail());
    306     return;
    307   }
    308   trie_.reset(new (std::nothrow) Trie);
    309   MARISA_ALPHA_THROW_IF(!has_trie(), MARISA_ALPHA_MEMORY_ERROR);
    310   trie_->build_trie(rkeys, terminals, ++progress);
    311 }
    312 
    313 template <typename T>
    314 UInt32 Trie::sort_keys(Vector<Key<T> > &keys) const {
    315   if (keys.empty()) {
    316     return 0;
    317   }
    318   for (UInt32 i = 0; i < keys.size(); ++i) {
    319     keys[i].set_id(i);
    320   }
    321   std::sort(keys.begin(), keys.end());
    322   UInt32 count = 1;
    323   for (UInt32 i = 1; i < keys.size(); ++i) {
    324     if (keys[i - 1].str() != keys[i].str()) {
    325       ++count;
    326     }
    327   }
    328   return count;
    329 }
    330 
    331 template <typename T>
    332 void Trie::build_terminals(const Vector<Key<T> > &keys,
    333     Vector<UInt32> *terminals) const {
    334   Vector<UInt32> temp_terminals;
    335   temp_terminals.resize(keys.size());
    336   for (UInt32 i = 0; i < keys.size(); ++i) {
    337     temp_terminals[keys[i].id()] = keys[i].terminal();
    338   }
    339   temp_terminals.swap(terminals);
    340 }
    341 
    342 }  // namespace marisa_alpha
    343