Home | History | Annotate | Download | only in marisa_alpha
      1 #ifndef MARISA_ALPHA_TRIE_INLINE_H_
      2 #define MARISA_ALPHA_TRIE_INLINE_H_
      3 
      4 #include <stdexcept>
      5 
      6 #include "cell.h"
      7 
      8 namespace marisa_alpha {
      9 
     10 inline std::string Trie::operator[](UInt32 key_id) const {
     11   std::string key;
     12   restore(key_id, &key);
     13   return key;
     14 }
     15 
     16 inline UInt32 Trie::operator[](const char *str) const {
     17   return lookup(str);
     18 }
     19 
     20 inline UInt32 Trie::operator[](const std::string &str) const {
     21   return lookup(str);
     22 }
     23 
     24 inline UInt32 Trie::lookup(const std::string &str) const {
     25   return lookup(str.c_str(), str.length());
     26 }
     27 
     28 inline std::size_t Trie::find(const std::string &str,
     29     UInt32 *key_ids, std::size_t *key_lengths,
     30     std::size_t max_num_results) const {
     31   return find(str.c_str(), str.length(),
     32       key_ids, key_lengths, max_num_results);
     33 }
     34 
     35 inline std::size_t Trie::find(const std::string &str,
     36     std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths,
     37     std::size_t max_num_results) const {
     38   return find(str.c_str(), str.length(),
     39       key_ids, key_lengths, max_num_results);
     40 }
     41 
     42 inline UInt32 Trie::find_first(const std::string &str,
     43     std::size_t *key_length) const {
     44   return find_first(str.c_str(), str.length(), key_length);
     45 }
     46 
     47 inline UInt32 Trie::find_last(const std::string &str,
     48     std::size_t *key_length) const {
     49   return find_last(str.c_str(), str.length(), key_length);
     50 }
     51 
     52 template <typename T>
     53 inline std::size_t Trie::find_callback(const char *str,
     54     T callback) const {
     55   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
     56   MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
     57   return find_callback_<CQuery>(CQuery(str), callback);
     58 }
     59 
     60 template <typename T>
     61 inline std::size_t Trie::find_callback(const char *ptr, std::size_t length,
     62     T callback) const {
     63   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
     64   MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
     65       MARISA_ALPHA_PARAM_ERROR);
     66   return find_callback_<const Query &>(Query(ptr, length), callback);
     67 }
     68 
     69 template <typename T>
     70 inline std::size_t Trie::find_callback(const std::string &str,
     71     T callback) const {
     72   return find_callback(str.c_str(), str.length(), callback);
     73 }
     74 
     75 inline std::size_t Trie::predict(const std::string &str,
     76     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
     77   return predict(str.c_str(), str.length(), key_ids, keys, max_num_results);
     78 }
     79 
     80 inline std::size_t Trie::predict(const std::string &str,
     81     std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
     82     std::size_t max_num_results) const {
     83   return predict(str.c_str(), str.length(), key_ids, keys, max_num_results);
     84 }
     85 
     86 inline std::size_t Trie::predict_breadth_first(const std::string &str,
     87     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
     88   return predict_breadth_first(str.c_str(), str.length(),
     89       key_ids, keys, max_num_results);
     90 }
     91 
     92 inline std::size_t Trie::predict_breadth_first(const std::string &str,
     93     std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
     94     std::size_t max_num_results) const {
     95   return predict_breadth_first(str.c_str(), str.length(),
     96       key_ids, keys, max_num_results);
     97 }
     98 
     99 inline std::size_t Trie::predict_depth_first(const std::string &str,
    100     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
    101   return predict_depth_first(str.c_str(), str.length(),
    102       key_ids, keys, max_num_results);
    103 }
    104 
    105 inline std::size_t Trie::predict_depth_first(const std::string &str,
    106     std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
    107     std::size_t max_num_results) const {
    108   return predict_depth_first(str.c_str(), str.length(),
    109       key_ids, keys, max_num_results);
    110 }
    111 
    112 template <typename T>
    113 inline std::size_t Trie::predict_callback(
    114     const char *str, T callback) const {
    115   return predict_callback_<CQuery>(CQuery(str), callback);
    116 }
    117 
    118 template <typename T>
    119 inline std::size_t Trie::predict_callback(
    120     const char *ptr, std::size_t length,
    121     T callback) const {
    122   return predict_callback_<const Query &>(Query(ptr, length), callback);
    123 }
    124 
    125 template <typename T>
    126 inline std::size_t Trie::predict_callback(
    127     const std::string &str, T callback) const {
    128   return predict_callback(str.c_str(), str.length(), callback);
    129 }
    130 
    131 inline bool Trie::empty() const {
    132   return louds_.empty();
    133 }
    134 
    135 inline std::size_t Trie::num_keys() const {
    136   return num_keys_;
    137 }
    138 
    139 inline UInt32 Trie::notfound() {
    140   return MARISA_ALPHA_NOT_FOUND;
    141 }
    142 
    143 inline std::size_t Trie::mismatch() {
    144   return MARISA_ALPHA_MISMATCH;
    145 }
    146 
    147 template <typename T>
    148 inline bool Trie::find_child(UInt32 &node, T query,
    149     std::size_t &pos) const {
    150   UInt32 louds_pos = get_child(node);
    151   if (!louds_[louds_pos]) {
    152     return false;
    153   }
    154   node = louds_pos_to_node(louds_pos, node);
    155   UInt32 link_id = MARISA_ALPHA_UINT32_MAX;
    156   do {
    157     if (has_link(node)) {
    158       if (link_id == MARISA_ALPHA_UINT32_MAX) {
    159         link_id = get_link_id(node);
    160       } else {
    161         ++link_id;
    162       }
    163       std::size_t next_pos = has_trie() ?
    164           trie_->trie_match<T>(get_link(node, link_id), query, pos) :
    165           tail_match<T>(node, link_id, query, pos);
    166       if (next_pos == mismatch()) {
    167         return false;
    168       } else if (next_pos != pos) {
    169         pos = next_pos;
    170         return true;
    171       }
    172     } else if (labels_[node] == query[pos]) {
    173       ++pos;
    174       return true;
    175     }
    176     ++node;
    177     ++louds_pos;
    178   } while (louds_[louds_pos]);
    179   return false;
    180 }
    181 
    182 template <typename T, typename U>
    183 std::size_t Trie::find_callback_(T query, U callback) const try {
    184   std::size_t count = 0;
    185   UInt32 node = 0;
    186   std::size_t pos = 0;
    187   do {
    188     if (terminal_flags_[node]) {
    189       ++count;
    190       if (!callback(node_to_key_id(node), pos)) {
    191         return count;
    192       }
    193     }
    194   } while (!query.ends_at(pos) && find_child<T>(node, query, pos));
    195   return count;
    196 } catch (const std::bad_alloc &) {
    197   MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
    198 } catch (const std::length_error &) {
    199   MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
    200 }
    201 
    202 template <typename T>
    203 inline bool Trie::predict_child(UInt32 &node, T query, std::size_t &pos,
    204     std::string *key) const {
    205   UInt32 louds_pos = get_child(node);
    206   if (!louds_[louds_pos]) {
    207     return false;
    208   }
    209   node = louds_pos_to_node(louds_pos, node);
    210   UInt32 link_id = MARISA_ALPHA_UINT32_MAX;
    211   do {
    212     if (has_link(node)) {
    213       if (link_id == MARISA_ALPHA_UINT32_MAX) {
    214         link_id = get_link_id(node);
    215       } else {
    216         ++link_id;
    217       }
    218       std::size_t next_pos = has_trie() ?
    219           trie_->trie_prefix_match<T>(
    220               get_link(node, link_id), query, pos, key) :
    221           tail_prefix_match<T>(node, link_id, query, pos, key);
    222       if (next_pos == mismatch()) {
    223         return false;
    224       } else if (next_pos != pos) {
    225         pos = next_pos;
    226         return true;
    227       }
    228     } else if (labels_[node] == query[pos]) {
    229       ++pos;
    230       return true;
    231     }
    232     ++node;
    233     ++louds_pos;
    234   } while (louds_[louds_pos]);
    235   return false;
    236 }
    237 
// Enumerates keys that start with `query`, calling `callback(key_id, key)`
// for each one, and returns how many times the callback was invoked. The
// count includes the call that returned false, if any, and the enumeration
// stops at that point.
//
// Two phases:
//   1. Descend from the root (node 0) with predict_child() until the query
//      is consumed. If a step fails, no key has this prefix and 0 is
//      returned.
//   2. Enumerate the subtree under the reached node depth-first, using an
//      explicit stack of Cell objects instead of recursion.
//
// std::bad_alloc is rethrown as MARISA_ALPHA_MEMORY_ERROR and
// std::length_error as MARISA_ALPHA_SIZE_ERROR.
template <typename T, typename U>
std::size_t Trie::predict_callback_(T query, U callback) const try {
  // Text of the key currently being reported. It is passed to
  // predict_child(), presumably so that link edges can contribute their
  // full text (confirm in trie_prefix_match / tail_prefix_match).
  std::string key;
  UInt32 node = 0;
  std::size_t pos = 0;
  while (!query.ends_at(pos)) {
    if (!predict_child<T>(node, query, pos, &key)) {
      return 0;
    }
  }
  // NOTE(review): presumably writes the consumed query text into `key` so
  // it spells the full path to `node` — confirm against Query/CQuery::insert.
  query.insert(&key);
  std::size_t count = 0;
  // The node reached by the query may itself be a key.
  if (terminal_flags_[node]) {
    ++count;
    if (!callback(node_to_key_id(node), key)) {
      return count;
    }
  }
  // Seed the stack with a cell for `node`'s children. The cell records:
  //   - the louds_ position of the current child bit,
  //   - the current child node ID,
  //   - the key ID to report for the next terminal at this depth,
  //   - the key length to truncate back to before appending a child's edge.
  Cell cell;
  cell.set_louds_pos(get_child(node));
  if (!louds_[cell.louds_pos()]) {
    return count;  // `node` has no children.
  }
  cell.set_node(louds_pos_to_node(cell.louds_pos(), node));
  cell.set_key_id(node_to_key_id(cell.node()));
  cell.set_length(key.length());
  Vector<Cell> stack;
  stack.push_back(cell);
  // Number of active levels. Cells at indices >= stack_pos are kept for
  // reuse when that depth is entered again.
  std::size_t stack_pos = 1;
  while (stack_pos != 0) {
    Cell &cur = stack[stack_pos - 1];
    if (!louds_[cur.louds_pos()]) {
      // A 0 bit ends the current child list. Step past it, so the cell is
      // already on the next child list when this depth is re-entered, then
      // go back up one level.
      cur.set_louds_pos(cur.louds_pos() + 1);
      --stack_pos;
      continue;
    }
    cur.set_louds_pos(cur.louds_pos() + 1);
    // Rebuild the key for cur.node(): the parent's prefix plus this edge,
    // which is either a linked string or a single label byte.
    key.resize(cur.length());
    if (has_link(cur.node())) {
      if (has_trie()) {
        trie_->trie_restore(get_link(cur.node()), &key);
      } else {
        tail_restore(cur.node(), &key);
      }
    } else {
      key += labels_[cur.node()];
    }
    if (terminal_flags_[cur.node()]) {
      ++count;
      if (!callback(cur.key_id(), key)) {
        return count;
      }
      // The next terminal visited at this depth gets the next key ID.
      cur.set_key_id(cur.key_id() + 1);
    }
    // First visit to the next depth: create its cell from cur.node()'s
    // children. On later visits the existing cell is reused. It already
    // points just past the previous sibling's child list, presumably because
    // child lists of consecutive nodes are stored back to back.
    // `cur` is not used after this push_back, which may reallocate `stack`.
    if (stack_pos == stack.size()) {
      cell.set_louds_pos(get_child(cur.node()));
      cell.set_node(louds_pos_to_node(cell.louds_pos(), cur.node()));
      cell.set_key_id(node_to_key_id(cell.node()));
      stack.push_back(cell);
    }
    // Descend: children truncate the key back to its current length. This
    // level resumes at the next sibling when control returns here.
    stack[stack_pos].set_length(key.length());
    stack[stack_pos - 1].set_node(stack[stack_pos - 1].node() + 1);
    ++stack_pos;
  }
  return count;
} catch (const std::bad_alloc &) {
  MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
} catch (const std::length_error &) {
  MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
}
    308 
    309 inline UInt32 Trie::key_id_to_node(UInt32 key_id) const {
    310   return terminal_flags_.select1(key_id);
    311 }
    312 
    313 inline UInt32 Trie::node_to_key_id(UInt32 node) const {
    314   return terminal_flags_.rank1(node);
    315 }
    316 
    317 inline UInt32 Trie::louds_pos_to_node(UInt32 louds_pos,
    318     UInt32 parent_node) const {
    319   return louds_pos - parent_node - 1;
    320 }
    321 
    322 inline UInt32 Trie::get_child(UInt32 node) const {
    323   return louds_.select0(node) + 1;
    324 }
    325 
    326 inline UInt32 Trie::get_parent(UInt32 node) const {
    327   return (node > num_first_branches_) ? (louds_.select1(node) - node - 1) : 0;
    328 }
    329 
    330 inline bool Trie::has_link(UInt32 node) const {
    331   return (link_flags_.empty()) ? false : link_flags_[node];
    332 }
    333 
    334 inline UInt32 Trie::get_link_id(UInt32 node) const {
    335   return link_flags_.rank1(node);
    336 }
    337 
    338 inline UInt32 Trie::get_link(UInt32 node) const {
    339   return get_link(node, get_link_id(node));
    340 }
    341 
    342 inline UInt32 Trie::get_link(UInt32 node, UInt32 link_id) const {
    343   return (links_[link_id] * 256) + labels_[node];
    344 }
    345 
    346 inline bool Trie::has_link() const {
    347   return !link_flags_.empty();
    348 }
    349 
    350 inline bool Trie::has_trie() const {
    351   return trie_.get() != NULL;
    352 }
    353 
    354 inline bool Trie::has_tail() const {
    355   return !tail_.empty();
    356 }
    357 
    358 }  // namespace marisa_alpha
    359 
    360 #endif  // MARISA_ALPHA_TRIE_INLINE_H_
    361