Home | History | Annotate | Download | only in marisa_alpha
      1 #include <algorithm>
      2 #include <stdexcept>
      3 
      4 #include "trie.h"
      5 
      6 namespace marisa_alpha {
      7 namespace {
      8 
      9 template <typename T, typename U>
     10 class PredictCallback {
     11  public:
     12   PredictCallback(T key_ids, U keys, std::size_t max_num_results)
     13       : key_ids_(key_ids), keys_(keys),
     14         max_num_results_(max_num_results), num_results_(0) {}
     15   PredictCallback(const PredictCallback &callback)
     16       : key_ids_(callback.key_ids_), keys_(callback.keys_),
     17         max_num_results_(callback.max_num_results_),
     18         num_results_(callback.num_results_) {}
     19 
     20   bool operator()(marisa_alpha::UInt32 key_id, const std::string &key) {
     21     if (key_ids_.is_valid()) {
     22       key_ids_.insert(num_results_, key_id);
     23     }
     24     if (keys_.is_valid()) {
     25       keys_.insert(num_results_, key);
     26     }
     27     return ++num_results_ < max_num_results_;
     28   }
     29 
     30  private:
     31   T key_ids_;
     32   U keys_;
     33   const std::size_t max_num_results_;
     34   std::size_t num_results_;
     35 
     36   // Disallows assignment.
     37   PredictCallback &operator=(const PredictCallback &);
     38 };
     39 
     40 }  // namespace
     41 
     42 std::string Trie::restore(UInt32 key_id) const {
     43   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
     44   MARISA_ALPHA_THROW_IF(key_id >= num_keys_, MARISA_ALPHA_PARAM_ERROR);
     45   std::string key;
     46   restore_(key_id, &key);
     47   return key;
     48 }
     49 
     50 void Trie::restore(UInt32 key_id, std::string *key) const {
     51   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
     52   MARISA_ALPHA_THROW_IF(key_id >= num_keys_, MARISA_ALPHA_PARAM_ERROR);
     53   MARISA_ALPHA_THROW_IF(key == NULL, MARISA_ALPHA_PARAM_ERROR);
     54   restore_(key_id, key);
     55 }
     56 
     57 std::size_t Trie::restore(UInt32 key_id, char *key_buf,
     58     std::size_t key_buf_size) const {
     59   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
     60   MARISA_ALPHA_THROW_IF(key_id >= num_keys_, MARISA_ALPHA_PARAM_ERROR);
     61   MARISA_ALPHA_THROW_IF((key_buf == NULL) && (key_buf_size != 0),
     62       MARISA_ALPHA_PARAM_ERROR);
     63   return restore_(key_id, key_buf, key_buf_size);
     64 }
     65 
     66 UInt32 Trie::lookup(const char *str) const {
     67   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
     68   MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
     69   return lookup_<CQuery>(CQuery(str));
     70 }
     71 
     72 UInt32 Trie::lookup(const char *ptr, std::size_t length) const {
     73   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
     74   MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
     75       MARISA_ALPHA_PARAM_ERROR);
     76   return lookup_<const Query &>(Query(ptr, length));
     77 }
     78 
     79 std::size_t Trie::find(const char *str,
     80     UInt32 *key_ids, std::size_t *key_lengths,
     81     std::size_t max_num_results) const {
     82   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
     83   MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
     84   return find_<CQuery>(CQuery(str),
     85       MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
     86 }
     87 
     88 std::size_t Trie::find(const char *ptr, std::size_t length,
     89     UInt32 *key_ids, std::size_t *key_lengths,
     90     std::size_t max_num_results) const {
     91   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
     92   MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
     93       MARISA_ALPHA_PARAM_ERROR);
     94   return find_<const Query &>(Query(ptr, length),
     95       MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
     96 }
     97 
     98 std::size_t Trie::find(const char *str,
     99     std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths,
    100     std::size_t max_num_results) const {
    101   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
    102   MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
    103   return find_<CQuery>(CQuery(str),
    104       MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
    105 }
    106 
    107 std::size_t Trie::find(const char *ptr, std::size_t length,
    108     std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths,
    109     std::size_t max_num_results) const {
    110   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
    111   MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
    112       MARISA_ALPHA_PARAM_ERROR);
    113   return find_<const Query &>(Query(ptr, length),
    114       MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
    115 }
    116 
    117 UInt32 Trie::find_first(const char *str,
    118     std::size_t *key_length) const {
    119   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
    120   MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
    121   return find_first_<CQuery>(CQuery(str), key_length);
    122 }
    123 
    124 UInt32 Trie::find_first(const char *ptr, std::size_t length,
    125     std::size_t *key_length) const {
    126   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
    127   MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
    128       MARISA_ALPHA_PARAM_ERROR);
    129   return find_first_<const Query &>(Query(ptr, length), key_length);
    130 }
    131 
    132 UInt32 Trie::find_last(const char *str,
    133     std::size_t *key_length) const {
    134   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
    135   MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
    136   return find_last_<CQuery>(CQuery(str), key_length);
    137 }
    138 
    139 UInt32 Trie::find_last(const char *ptr, std::size_t length,
    140     std::size_t *key_length) const {
    141   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
    142   MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
    143       MARISA_ALPHA_PARAM_ERROR);
    144   return find_last_<const Query &>(Query(ptr, length), key_length);
    145 }
    146 
    147 std::size_t Trie::predict(const char *str,
    148     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
    149   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
    150   MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
    151   return (keys == NULL) ?
    152       predict_breadth_first(str, key_ids, keys, max_num_results) :
    153       predict_depth_first(str, key_ids, keys, max_num_results);
    154 }
    155 
    156 std::size_t Trie::predict(const char *ptr, std::size_t length,
    157     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
    158   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
    159   MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
    160       MARISA_ALPHA_PARAM_ERROR);
    161   return (keys == NULL) ?
    162       predict_breadth_first(ptr, length, key_ids, keys, max_num_results) :
    163       predict_depth_first(ptr, length, key_ids, keys, max_num_results);
    164 }
    165 
    166 std::size_t Trie::predict(const char *str,
    167     std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
    168     std::size_t max_num_results) const {
    169   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
    170   MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
    171   return (keys == NULL) ?
    172       predict_breadth_first(str, key_ids, keys, max_num_results) :
    173       predict_depth_first(str, key_ids, keys, max_num_results);
    174 }
    175 
    176 std::size_t Trie::predict(const char *ptr, std::size_t length,
    177     std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
    178     std::size_t max_num_results) const {
    179   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
    180   MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
    181       MARISA_ALPHA_PARAM_ERROR);
    182   return (keys == NULL) ?
    183       predict_breadth_first(ptr, length, key_ids, keys, max_num_results) :
    184       predict_depth_first(ptr, length, key_ids, keys, max_num_results);
    185 }
    186 
    187 std::size_t Trie::predict_breadth_first(const char *str,
    188     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
    189   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
    190   MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
    191   return predict_breadth_first_<CQuery>(CQuery(str),
    192       MakeContainer(key_ids), MakeContainer(keys), max_num_results);
    193 }
    194 
    195 std::size_t Trie::predict_breadth_first(const char *ptr, std::size_t length,
    196     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
    197   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
    198   MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
    199       MARISA_ALPHA_PARAM_ERROR);
    200   return predict_breadth_first_<const Query &>(Query(ptr, length),
    201       MakeContainer(key_ids), MakeContainer(keys), max_num_results);
    202 }
    203 
    204 std::size_t Trie::predict_breadth_first(const char *str,
    205     std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
    206     std::size_t max_num_results) const {
    207   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
    208   MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
    209   return predict_breadth_first_<CQuery>(CQuery(str),
    210       MakeContainer(key_ids), MakeContainer(keys), max_num_results);
    211 }
    212 
    213 std::size_t Trie::predict_breadth_first(const char *ptr, std::size_t length,
    214     std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
    215     std::size_t max_num_results) const {
    216   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
    217   MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
    218       MARISA_ALPHA_PARAM_ERROR);
    219   return predict_breadth_first_<const Query &>(Query(ptr, length),
    220       MakeContainer(key_ids), MakeContainer(keys), max_num_results);
    221 }
    222 
    223 std::size_t Trie::predict_depth_first(const char *str,
    224     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
    225   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
    226   MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
    227   return predict_depth_first_<CQuery>(CQuery(str),
    228       MakeContainer(key_ids), MakeContainer(keys), max_num_results);
    229 }
    230 
    231 std::size_t Trie::predict_depth_first(const char *ptr, std::size_t length,
    232     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
    233   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
    234   MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
    235       MARISA_ALPHA_PARAM_ERROR);
    236   return predict_depth_first_<const Query &>(Query(ptr, length),
    237       MakeContainer(key_ids), MakeContainer(keys), max_num_results);
    238 }
    239 
    240 std::size_t Trie::predict_depth_first(
    241     const char *str, std::vector<UInt32> *key_ids,
    242     std::vector<std::string> *keys, std::size_t max_num_results) const {
    243   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
    244   MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
    245   return predict_depth_first_<CQuery>(CQuery(str),
    246       MakeContainer(key_ids), MakeContainer(keys), max_num_results);
    247 }
    248 
    249 std::size_t Trie::predict_depth_first(
    250     const char *ptr, std::size_t length, std::vector<UInt32> *key_ids,
    251     std::vector<std::string> *keys, std::size_t max_num_results) const {
    252   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
    253   MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
    254       MARISA_ALPHA_PARAM_ERROR);
    255   return predict_depth_first_<const Query &>(Query(ptr, length),
    256       MakeContainer(key_ids), MakeContainer(keys), max_num_results);
    257 }
    258 
    259 void Trie::restore_(UInt32 key_id, std::string *key) const {
    260   const std::size_t start_pos = key->length();
    261   try {
    262     UInt32 node = key_id_to_node(key_id);
    263     while (node != 0) {
    264       if (has_link(node)) {
    265         const std::size_t prev_pos = key->length();
    266         if (has_trie()) {
    267           trie_->trie_restore(get_link(node), key);
    268         } else {
    269           tail_restore(node, key);
    270         }
    271         std::reverse(key->begin() + prev_pos, key->end());
    272       } else {
    273         *key += labels_[node];
    274       }
    275       node = get_parent(node);
    276     }
    277     std::reverse(key->begin() + start_pos, key->end());
    278   } catch (const std::bad_alloc &) {
    279     key->resize(start_pos);
    280     MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
    281   } catch (const std::length_error &) {
    282     key->resize(start_pos);
    283     MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
    284   }
    285 }
    286 
    287 void Trie::trie_restore(UInt32 node, std::string *key) const {
    288   do {
    289     if (has_link(node)) {
    290       if (has_trie()) {
    291         trie_->trie_restore(get_link(node), key);
    292       } else {
    293         tail_restore(node, key);
    294       }
    295     } else {
    296       *key += labels_[node];
    297     }
    298     node = get_parent(node);
    299   } while (node != 0);
    300 }
    301 
    302 void Trie::tail_restore(UInt32 node, std::string *key) const {
    303   const UInt32 link_id = link_flags_.rank1(node);
    304   const UInt32 offset = (links_[link_id] * 256) + labels_[node];
    305   if (tail_.mode() == MARISA_ALPHA_BINARY_TAIL) {
    306     const UInt32 length = (links_[link_id + 1] * 256)
    307         + labels_[link_flags_.select1(link_id + 1)] - offset;
    308     key->append(reinterpret_cast<const char *>(tail_[offset]), length);
    309   } else {
    310     key->append(reinterpret_cast<const char *>(tail_[offset]));
    311   }
    312 }
    313 
    314 std::size_t Trie::restore_(UInt32 key_id, char *key_buf,
    315     std::size_t key_buf_size) const {
    316   std::size_t pos = 0;
    317   UInt32 node = key_id_to_node(key_id);
    318   while (node != 0) {
    319     if (has_link(node)) {
    320       const std::size_t prev_pos = pos;
    321       if (has_trie()) {
    322         trie_->trie_restore(get_link(node), key_buf, key_buf_size, pos);
    323       } else {
    324         tail_restore(node, key_buf, key_buf_size, pos);
    325       }
    326       if (pos < key_buf_size) {
    327         std::reverse(key_buf + prev_pos, key_buf + pos);
    328       }
    329     } else {
    330       if (pos < key_buf_size) {
    331         key_buf[pos] = labels_[node];
    332       }
    333       ++pos;
    334     }
    335     node = get_parent(node);
    336   }
    337   if (pos < key_buf_size) {
    338     key_buf[pos] = '\0';
    339     std::reverse(key_buf, key_buf + pos);
    340   }
    341   return pos;
    342 }
    343 
    344 void Trie::trie_restore(UInt32 node, char *key_buf,
    345     std::size_t key_buf_size, std::size_t &pos) const {
    346   do {
    347     if (has_link(node)) {
    348       if (has_trie()) {
    349         trie_->trie_restore(get_link(node), key_buf, key_buf_size, pos);
    350       } else {
    351         tail_restore(node, key_buf, key_buf_size, pos);
    352       }
    353     } else {
    354       if (pos < key_buf_size) {
    355         key_buf[pos] = labels_[node];
    356       }
    357       ++pos;
    358     }
    359     node = get_parent(node);
    360   } while (node != 0);
    361 }
    362 
    363 void Trie::tail_restore(UInt32 node, char *key_buf,
    364     std::size_t key_buf_size, std::size_t &pos) const {
    365   const UInt32 link_id = link_flags_.rank1(node);
    366   const UInt32 offset = (links_[link_id] * 256) + labels_[node];
    367   if (tail_.mode() == MARISA_ALPHA_BINARY_TAIL) {
    368     const UInt8 *ptr = tail_[offset];
    369     const UInt32 length = (links_[link_id + 1] * 256)
    370         + labels_[link_flags_.select1(link_id + 1)] - offset;
    371     for (UInt32 i = 0; i < length; ++i) {
    372       if (pos < key_buf_size) {
    373         key_buf[pos] = ptr[i];
    374       }
    375       ++pos;
    376     }
    377   } else {
    378     for (const UInt8 *str = tail_[offset]; *str != '\0'; ++str) {
    379       if (pos < key_buf_size) {
    380         key_buf[pos] = *str;
    381       }
    382       ++pos;
    383     }
    384   }
    385 }
    386 
    387 template <typename T>
    388 UInt32 Trie::lookup_(T query) const {
    389   UInt32 node = 0;
    390   std::size_t pos = 0;
    391   while (!query.ends_at(pos)) {
    392     if (!find_child<T>(node, query, pos)) {
    393       return notfound();
    394     }
    395   }
    396   return terminal_flags_[node] ? node_to_key_id(node) : notfound();
    397 }
    398 
    399 template <typename T>
    400 std::size_t Trie::trie_match(UInt32 node, T query,
    401     std::size_t pos) const {
    402   if (has_link(node)) {
    403     std::size_t next_pos;
    404     if (has_trie()) {
    405       next_pos = trie_->trie_match<T>(get_link(node), query, pos);
    406     } else {
    407       next_pos = tail_match<T>(node, get_link_id(node), query, pos);
    408     }
    409     if ((next_pos == mismatch()) || (next_pos == pos)) {
    410       return next_pos;
    411     }
    412     pos = next_pos;
    413   } else if (labels_[node] != query[pos]) {
    414     return pos;
    415   } else {
    416     ++pos;
    417   }
    418   node = get_parent(node);
    419   while (node != 0) {
    420     if (query.ends_at(pos)) {
    421       return mismatch();
    422     }
    423     if (has_link(node)) {
    424       std::size_t next_pos;
    425       if (has_trie()) {
    426         next_pos = trie_->trie_match<T>(get_link(node), query, pos);
    427       } else {
    428         next_pos = tail_match<T>(node, get_link_id(node), query, pos);
    429       }
    430       if ((next_pos == mismatch()) || (next_pos == pos)) {
    431         return mismatch();
    432       }
    433       pos = next_pos;
    434     } else if (labels_[node] != query[pos]) {
    435       return mismatch();
    436     } else {
    437       ++pos;
    438     }
    439     node = get_parent(node);
    440   }
    441   return pos;
    442 }
    443 
    444 template std::size_t Trie::trie_match<CQuery>(UInt32 node,
    445     CQuery query, std::size_t pos) const;
    446 template std::size_t Trie::trie_match<const Query &>(UInt32 node,
    447     const Query &query, std::size_t pos) const;
    448 
    449 template <typename T>
    450 std::size_t Trie::tail_match(UInt32 node, UInt32 link_id,
    451     T query, std::size_t pos) const {
    452   const UInt32 offset = (links_[link_id] * 256) + labels_[node];
    453   const UInt8 *ptr = tail_[offset];
    454   if (*ptr != query[pos]) {
    455     return pos;
    456   } else if (tail_.mode() == MARISA_ALPHA_BINARY_TAIL) {
    457     const UInt32 length = (links_[link_id + 1] * 256)
    458         + labels_[link_flags_.select1(link_id + 1)] - offset;
    459     for (UInt32 i = 1; i < length; ++i) {
    460       if (query.ends_at(pos + i) || (ptr[i] != query[pos + i])) {
    461         return mismatch();
    462       }
    463     }
    464     return pos + length;
    465   } else {
    466     for (++ptr, ++pos; *ptr != '\0'; ++ptr, ++pos) {
    467       if (query.ends_at(pos) || (*ptr != query[pos])) {
    468         return mismatch();
    469       }
    470     }
    471     return pos;
    472   }
    473 }
    474 
    475 template std::size_t Trie::tail_match<CQuery>(UInt32 node,
    476     UInt32 link_id, CQuery query, std::size_t pos) const;
    477 template std::size_t Trie::tail_match<const Query &>(UInt32 node,
    478     UInt32 link_id, const Query &query, std::size_t pos) const;
    479 
    480 template <typename T, typename U, typename V>
    481 std::size_t Trie::find_(T query, U key_ids, V key_lengths,
    482     std::size_t max_num_results) const try {
    483   if (max_num_results == 0) {
    484     return 0;
    485   }
    486   std::size_t count = 0;
    487   UInt32 node = 0;
    488   std::size_t pos = 0;
    489   do {
    490     if (terminal_flags_[node]) {
    491       if (key_ids.is_valid()) {
    492         key_ids.insert(count, node_to_key_id(node));
    493       }
    494       if (key_lengths.is_valid()) {
    495         key_lengths.insert(count, pos);
    496       }
    497       if (++count >= max_num_results) {
    498         return count;
    499       }
    500     }
    501   } while (!query.ends_at(pos) && find_child<T>(node, query, pos));
    502   return count;
    503 } catch (const std::bad_alloc &) {
    504   MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
    505 } catch (const std::length_error &) {
    506   MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
    507 }
    508 
    509 template <typename T>
    510 UInt32 Trie::find_first_(T query, std::size_t *key_length) const {
    511   UInt32 node = 0;
    512   std::size_t pos = 0;
    513   do {
    514     if (terminal_flags_[node]) {
    515       if (key_length != NULL) {
    516         *key_length = pos;
    517       }
    518       return node_to_key_id(node);
    519     }
    520   } while (!query.ends_at(pos) && find_child<T>(node, query, pos));
    521   return notfound();
    522 }
    523 
    524 template <typename T>
    525 UInt32 Trie::find_last_(T query, std::size_t *key_length) const {
    526   UInt32 node = 0;
    527   UInt32 node_found = notfound();
    528   std::size_t pos = 0;
    529   std::size_t pos_found = mismatch();
    530   do {
    531     if (terminal_flags_[node]) {
    532       node_found = node;
    533       pos_found = pos;
    534     }
    535   } while (!query.ends_at(pos) && find_child<T>(node, query, pos));
    536   if (node_found != notfound()) {
    537     if (key_length != NULL) {
    538       *key_length = pos_found;
    539     }
    540     return node_to_key_id(node_found);
    541   }
    542   return notfound();
    543 }
    544 
    545 template <typename T, typename U, typename V>
    546 std::size_t Trie::predict_breadth_first_(T query, U key_ids, V keys,
    547     std::size_t max_num_results) const try {
    548   if (max_num_results == 0) {
    549     return 0;
    550   }
    551   UInt32 node = 0;
    552   std::size_t pos = 0;
    553   while (!query.ends_at(pos)) {
    554     if (!predict_child<T>(node, query, pos, NULL)) {
    555       return 0;
    556     }
    557   }
    558   std::string key;
    559   std::size_t count = 0;
    560   if (terminal_flags_[node]) {
    561     const UInt32 key_id = node_to_key_id(node);
    562     if (key_ids.is_valid()) {
    563       key_ids.insert(count, key_id);
    564     }
    565     if (keys.is_valid()) {
    566       restore(key_id, &key);
    567       keys.insert(count, key);
    568     }
    569     if (++count >= max_num_results) {
    570       return count;
    571     }
    572   }
    573   const UInt32 louds_pos = get_child(node);
    574   if (!louds_[louds_pos]) {
    575     return count;
    576   }
    577   UInt32 node_begin = louds_pos_to_node(louds_pos, node);
    578   UInt32 node_end = louds_pos_to_node(get_child(node + 1), node + 1);
    579   while (node_begin < node_end) {
    580     const UInt32 key_id_begin = node_to_key_id(node_begin);
    581     const UInt32 key_id_end = node_to_key_id(node_end);
    582     if (key_ids.is_valid()) {
    583       UInt32 temp_count = count;
    584       for (UInt32 key_id = key_id_begin; key_id < key_id_end; ++key_id) {
    585         key_ids.insert(temp_count, key_id);
    586         if (++temp_count >= max_num_results) {
    587           break;
    588         }
    589       }
    590     }
    591     if (keys.is_valid()) {
    592       UInt32 temp_count = count;
    593       for (UInt32 key_id = key_id_begin; key_id < key_id_end; ++key_id) {
    594         key.clear();
    595         restore(key_id, &key);
    596         keys.insert(temp_count, key);
    597         if (++temp_count >= max_num_results) {
    598           break;
    599         }
    600       }
    601     }
    602     count += key_id_end - key_id_begin;
    603     if (count >= max_num_results) {
    604       return max_num_results;
    605     }
    606     node_begin = louds_pos_to_node(get_child(node_begin), node_begin);
    607     node_end = louds_pos_to_node(get_child(node_end), node_end);
    608   }
    609   return count;
    610 } catch (const std::bad_alloc &) {
    611   MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
    612 } catch (const std::length_error &) {
    613   MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
    614 }
    615 
    616 template <typename T, typename U, typename V>
    617 std::size_t Trie::predict_depth_first_(T query, U key_ids, V keys,
    618     std::size_t max_num_results) const try {
    619   if (max_num_results == 0) {
    620     return 0;
    621   } else if (keys.is_valid()) {
    622     PredictCallback<U, V> callback(key_ids, keys, max_num_results);
    623     return predict_callback_(query, callback);
    624   }
    625 
    626   UInt32 node = 0;
    627   std::size_t pos = 0;
    628   while (!query.ends_at(pos)) {
    629     if (!predict_child<T>(node, query, pos, NULL)) {
    630       return 0;
    631     }
    632   }
    633   std::size_t count = 0;
    634   if (terminal_flags_[node]) {
    635     if (key_ids.is_valid()) {
    636       key_ids.insert(count, node_to_key_id(node));
    637     }
    638     if (++count >= max_num_results) {
    639       return count;
    640     }
    641   }
    642   Cell cell;
    643   cell.set_louds_pos(get_child(node));
    644   if (!louds_[cell.louds_pos()]) {
    645     return count;
    646   }
    647   cell.set_node(louds_pos_to_node(cell.louds_pos(), node));
    648   cell.set_key_id(node_to_key_id(cell.node()));
    649   Vector<Cell> stack;
    650   stack.push_back(cell);
    651   std::size_t stack_pos = 1;
    652   while (stack_pos != 0) {
    653     Cell &cur = stack[stack_pos - 1];
    654     if (!louds_[cur.louds_pos()]) {
    655       cur.set_louds_pos(cur.louds_pos() + 1);
    656       --stack_pos;
    657       continue;
    658     }
    659     cur.set_louds_pos(cur.louds_pos() + 1);
    660     if (terminal_flags_[cur.node()]) {
    661       if (key_ids.is_valid()) {
    662         key_ids.insert(count, cur.key_id());
    663       }
    664       if (++count >= max_num_results) {
    665         return count;
    666       }
    667       cur.set_key_id(cur.key_id() + 1);
    668     }
    669     if (stack_pos == stack.size()) {
    670       cell.set_louds_pos(get_child(cur.node()));
    671       cell.set_node(louds_pos_to_node(cell.louds_pos(), cur.node()));
    672       cell.set_key_id(node_to_key_id(cell.node()));
    673       stack.push_back(cell);
    674     }
    675     stack[stack_pos - 1].set_node(stack[stack_pos - 1].node() + 1);
    676     ++stack_pos;
    677   }
    678   return count;
    679 } catch (const std::bad_alloc &) {
    680   MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
    681 } catch (const std::length_error &) {
    682   MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
    683 }
    684 
    685 template <typename T>
    686 std::size_t Trie::trie_prefix_match(UInt32 node, T query,
    687     std::size_t pos, std::string *key) const {
    688   if (has_link(node)) {
    689     std::size_t next_pos;
    690     if (has_trie()) {
    691       next_pos = trie_->trie_prefix_match<T>(get_link(node), query, pos, key);
    692     } else {
    693       next_pos = tail_prefix_match<T>(
    694           node, get_link_id(node), query, pos, key);
    695     }
    696     if ((next_pos == mismatch()) || (next_pos == pos)) {
    697       return next_pos;
    698     }
    699     pos = next_pos;
    700   } else if (labels_[node] != query[pos]) {
    701     return pos;
    702   } else {
    703     ++pos;
    704   }
    705   node = get_parent(node);
    706   while (node != 0) {
    707     if (query.ends_at(pos)) {
    708       if (key != NULL) {
    709         trie_restore(node, key);
    710       }
    711       return pos;
    712     }
    713     if (has_link(node)) {
    714       std::size_t next_pos;
    715       if (has_trie()) {
    716         next_pos = trie_->trie_prefix_match<T>(
    717             get_link(node), query, pos, key);
    718       } else {
    719         next_pos = tail_prefix_match<T>(
    720             node, get_link_id(node), query, pos, key);
    721       }
    722       if ((next_pos == mismatch()) || (next_pos == pos)) {
    723         return next_pos;
    724       }
    725       pos = next_pos;
    726     } else if (labels_[node] != query[pos]) {
    727       return mismatch();
    728     } else {
    729       ++pos;
    730     }
    731     node = get_parent(node);
    732   }
    733   return pos;
    734 }
    735 
    736 template std::size_t Trie::trie_prefix_match<CQuery>(UInt32 node,
    737     CQuery query, std::size_t pos, std::string *key) const;
    738 template std::size_t Trie::trie_prefix_match<const Query &>(UInt32 node,
    739     const Query &query, std::size_t pos, std::string *key) const;
    740 
    741 template <typename T>
    742 std::size_t Trie::tail_prefix_match(UInt32 node, UInt32 link_id,
    743     T query, std::size_t pos, std::string *key) const {
    744   const UInt32 offset = (links_[link_id] * 256) + labels_[node];
    745   const UInt8 *ptr = tail_[offset];
    746   if (*ptr != query[pos]) {
    747     return pos;
    748   } else if (tail_.mode() == MARISA_ALPHA_BINARY_TAIL) {
    749     const UInt32 length = (links_[link_id + 1] * 256)
    750         + labels_[link_flags_.select1(link_id + 1)] - offset;
    751     for (UInt32 i = 1; i < length; ++i) {
    752       if (query.ends_at(pos + i)) {
    753         if (key != NULL) {
    754           key->append(reinterpret_cast<const char *>(ptr + i), length - i);
    755         }
    756         return pos + i;
    757       } else if (ptr[i] != query[pos + i]) {
    758         return mismatch();
    759       }
    760     }
    761     return pos + length;
    762   } else {
    763     for (++ptr, ++pos; *ptr != '\0'; ++ptr, ++pos) {
    764       if (query.ends_at(pos)) {
    765         if (key != NULL) {
    766           key->append(reinterpret_cast<const char *>(ptr));
    767         }
    768         return pos;
    769       } else if (*ptr != query[pos]) {
    770         return mismatch();
    771       }
    772     }
    773     return pos;
    774   }
    775 }
    776 
    777 template std::size_t Trie::tail_prefix_match<CQuery>(
    778     UInt32 node, UInt32 link_id,
    779     CQuery query, std::size_t pos, std::string *key) const;
    780 template std::size_t Trie::tail_prefix_match<const Query &>(
    781     UInt32 node, UInt32 link_id,
    782     const Query &query, std::size_t pos, std::string *key) const;
    783 
    784 }  // namespace marisa_alpha
    785