Home | History | Annotate | Download | only in marisa
      1 #include <algorithm>
      2 #include <stdexcept>
      3 
      4 #include "trie.h"
      5 
      6 namespace marisa {
      7 namespace {
      8 
      9 template <typename T, typename U>
     10 class PredictCallback {
     11  public:
     12   PredictCallback(T key_ids, U keys, std::size_t max_num_results)
     13       : key_ids_(key_ids), keys_(keys),
     14         max_num_results_(max_num_results), num_results_(0) {}
     15   PredictCallback(const PredictCallback &callback)
     16       : key_ids_(callback.key_ids_), keys_(callback.keys_),
     17         max_num_results_(callback.max_num_results_),
     18         num_results_(callback.num_results_) {}
     19 
     20   bool operator()(marisa::UInt32 key_id, const std::string &key) {
     21     if (key_ids_.is_valid()) {
     22       key_ids_.insert(num_results_, key_id);
     23     }
     24     if (keys_.is_valid()) {
     25       keys_.insert(num_results_, key);
     26     }
     27     return ++num_results_ < max_num_results_;
     28   }
     29 
     30  private:
     31   T key_ids_;
     32   U keys_;
     33   const std::size_t max_num_results_;
     34   std::size_t num_results_;
     35 
     36   // Disallows assignment.
     37   PredictCallback &operator=(const PredictCallback &);
     38 };
     39 
     40 }  // namespace
     41 
     42 std::string Trie::restore(UInt32 key_id) const {
     43   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
     44   MARISA_THROW_IF(key_id >= num_keys_, MARISA_PARAM_ERROR);
     45   std::string key;
     46   restore_(key_id, &key);
     47   return key;
     48 }
     49 
     50 void Trie::restore(UInt32 key_id, std::string *key) const {
     51   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
     52   MARISA_THROW_IF(key_id >= num_keys_, MARISA_PARAM_ERROR);
     53   MARISA_THROW_IF(key == NULL, MARISA_PARAM_ERROR);
     54   restore_(key_id, key);
     55 }
     56 
     57 std::size_t Trie::restore(UInt32 key_id, char *key_buf,
     58     std::size_t key_buf_size) const {
     59   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
     60   MARISA_THROW_IF(key_id >= num_keys_, MARISA_PARAM_ERROR);
     61   MARISA_THROW_IF((key_buf == NULL) && (key_buf_size != 0),
     62       MARISA_PARAM_ERROR);
     63   return restore_(key_id, key_buf, key_buf_size);
     64 }
     65 
     66 UInt32 Trie::lookup(const char *str) const {
     67   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
     68   MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
     69   return lookup_<CQuery>(CQuery(str));
     70 }
     71 
     72 UInt32 Trie::lookup(const char *ptr, std::size_t length) const {
     73   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
     74   MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
     75   return lookup_<const Query &>(Query(ptr, length));
     76 }
     77 
     78 std::size_t Trie::find(const char *str,
     79     UInt32 *key_ids, std::size_t *key_lengths,
     80     std::size_t max_num_results) const {
     81   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
     82   MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
     83   return find_<CQuery>(CQuery(str),
     84       MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
     85 }
     86 
     87 std::size_t Trie::find(const char *ptr, std::size_t length,
     88     UInt32 *key_ids, std::size_t *key_lengths,
     89     std::size_t max_num_results) const {
     90   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
     91   MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
     92   return find_<const Query &>(Query(ptr, length),
     93       MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
     94 }
     95 
     96 std::size_t Trie::find(const char *str,
     97     std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths,
     98     std::size_t max_num_results) const {
     99   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
    100   MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
    101   return find_<CQuery>(CQuery(str),
    102       MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
    103 }
    104 
    105 std::size_t Trie::find(const char *ptr, std::size_t length,
    106     std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths,
    107     std::size_t max_num_results) const {
    108   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
    109   MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
    110   return find_<const Query &>(Query(ptr, length),
    111       MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
    112 }
    113 
    114 UInt32 Trie::find_first(const char *str,
    115     std::size_t *key_length) const {
    116   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
    117   MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
    118   return find_first_<CQuery>(CQuery(str), key_length);
    119 }
    120 
    121 UInt32 Trie::find_first(const char *ptr, std::size_t length,
    122     std::size_t *key_length) const {
    123   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
    124   MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
    125   return find_first_<const Query &>(Query(ptr, length), key_length);
    126 }
    127 
    128 UInt32 Trie::find_last(const char *str,
    129     std::size_t *key_length) const {
    130   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
    131   MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
    132   return find_last_<CQuery>(CQuery(str), key_length);
    133 }
    134 
    135 UInt32 Trie::find_last(const char *ptr, std::size_t length,
    136     std::size_t *key_length) const {
    137   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
    138   MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
    139   return find_last_<const Query &>(Query(ptr, length), key_length);
    140 }
    141 
    142 std::size_t Trie::predict(const char *str,
    143     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
    144   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
    145   MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
    146   return (keys == NULL) ?
    147       predict_breadth_first(str, key_ids, keys, max_num_results) :
    148       predict_depth_first(str, key_ids, keys, max_num_results);
    149 }
    150 
    151 std::size_t Trie::predict(const char *ptr, std::size_t length,
    152     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
    153   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
    154   MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
    155   return (keys == NULL) ?
    156       predict_breadth_first(ptr, length, key_ids, keys, max_num_results) :
    157       predict_depth_first(ptr, length, key_ids, keys, max_num_results);
    158 }
    159 
    160 std::size_t Trie::predict(const char *str,
    161     std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
    162     std::size_t max_num_results) const {
    163   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
    164   MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
    165   return (keys == NULL) ?
    166       predict_breadth_first(str, key_ids, keys, max_num_results) :
    167       predict_depth_first(str, key_ids, keys, max_num_results);
    168 }
    169 
    170 std::size_t Trie::predict(const char *ptr, std::size_t length,
    171     std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
    172     std::size_t max_num_results) const {
    173   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
    174   MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
    175   return (keys == NULL) ?
    176       predict_breadth_first(ptr, length, key_ids, keys, max_num_results) :
    177       predict_depth_first(ptr, length, key_ids, keys, max_num_results);
    178 }
    179 
    180 std::size_t Trie::predict_breadth_first(const char *str,
    181     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
    182   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
    183   MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
    184   return predict_breadth_first_<CQuery>(CQuery(str),
    185       MakeContainer(key_ids), MakeContainer(keys), max_num_results);
    186 }
    187 
    188 std::size_t Trie::predict_breadth_first(const char *ptr, std::size_t length,
    189     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
    190   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
    191   MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
    192   return predict_breadth_first_<const Query &>(Query(ptr, length),
    193       MakeContainer(key_ids), MakeContainer(keys), max_num_results);
    194 }
    195 
    196 std::size_t Trie::predict_breadth_first(const char *str,
    197     std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
    198     std::size_t max_num_results) const {
    199   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
    200   MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
    201   return predict_breadth_first_<CQuery>(CQuery(str),
    202       MakeContainer(key_ids), MakeContainer(keys), max_num_results);
    203 }
    204 
    205 std::size_t Trie::predict_breadth_first(const char *ptr, std::size_t length,
    206     std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
    207     std::size_t max_num_results) const {
    208   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
    209   MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
    210   return predict_breadth_first_<const Query &>(Query(ptr, length),
    211       MakeContainer(key_ids), MakeContainer(keys), max_num_results);
    212 }
    213 
    214 std::size_t Trie::predict_depth_first(const char *str,
    215     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
    216   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
    217   MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
    218   return predict_depth_first_<CQuery>(CQuery(str),
    219       MakeContainer(key_ids), MakeContainer(keys), max_num_results);
    220 }
    221 
    222 std::size_t Trie::predict_depth_first(const char *ptr, std::size_t length,
    223     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
    224   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
    225   MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
    226   return predict_depth_first_<const Query &>(Query(ptr, length),
    227       MakeContainer(key_ids), MakeContainer(keys), max_num_results);
    228 }
    229 
    230 std::size_t Trie::predict_depth_first(
    231     const char *str, std::vector<UInt32> *key_ids,
    232     std::vector<std::string> *keys, std::size_t max_num_results) const {
    233   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
    234   MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
    235   return predict_depth_first_<CQuery>(CQuery(str),
    236       MakeContainer(key_ids), MakeContainer(keys), max_num_results);
    237 }
    238 
    239 std::size_t Trie::predict_depth_first(
    240     const char *ptr, std::size_t length, std::vector<UInt32> *key_ids,
    241     std::vector<std::string> *keys, std::size_t max_num_results) const {
    242   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
    243   MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
    244   return predict_depth_first_<const Query &>(Query(ptr, length),
    245       MakeContainer(key_ids), MakeContainer(keys), max_num_results);
    246 }
    247 
    248 void Trie::restore_(UInt32 key_id, std::string *key) const {
    249   const std::size_t start_pos = key->length();
    250   UInt32 node = key_id_to_node(key_id);
    251   while (node != 0) {
    252     if (has_link(node)) {
    253       const std::size_t prev_pos = key->length();
    254       if (has_trie()) {
    255         trie_->trie_restore(get_link(node), key);
    256       } else {
    257         tail_restore(node, key);
    258       }
    259       std::reverse(key->begin() + prev_pos, key->end());
    260     } else {
    261       *key += labels_[node];
    262     }
    263     node = get_parent(node);
    264   }
    265   std::reverse(key->begin() + start_pos, key->end());
    266 }
    267 
    268 void Trie::trie_restore(UInt32 node, std::string *key) const {
    269   do {
    270     if (has_link(node)) {
    271       if (has_trie()) {
    272         trie_->trie_restore(get_link(node), key);
    273       } else {
    274         tail_restore(node, key);
    275       }
    276     } else {
    277       *key += labels_[node];
    278     }
    279     node = get_parent(node);
    280   } while (node != 0);
    281 }
    282 
    283 void Trie::tail_restore(UInt32 node, std::string *key) const {
    284   const UInt32 link_id = link_flags_.rank1(node);
    285   const UInt32 offset = (links_[link_id] * 256) + labels_[node];
    286   if (tail_.mode() == MARISA_BINARY_TAIL) {
    287     const UInt32 length = (links_[link_id + 1] * 256)
    288         + labels_[link_flags_.select1(link_id + 1)] - offset;
    289     key->append(reinterpret_cast<const char *>(tail_[offset]), length);
    290   } else {
    291     key->append(reinterpret_cast<const char *>(tail_[offset]));
    292   }
    293 }
    294 
    295 std::size_t Trie::restore_(UInt32 key_id, char *key_buf,
    296     std::size_t key_buf_size) const {
    297   std::size_t pos = 0;
    298   UInt32 node = key_id_to_node(key_id);
    299   while (node != 0) {
    300     if (has_link(node)) {
    301       const std::size_t prev_pos = pos;
    302       if (has_trie()) {
    303         trie_->trie_restore(get_link(node), key_buf, key_buf_size, pos);
    304       } else {
    305         tail_restore(node, key_buf, key_buf_size, pos);
    306       }
    307       if (pos < key_buf_size) {
    308         std::reverse(key_buf + prev_pos, key_buf + pos);
    309       }
    310     } else {
    311       if (pos < key_buf_size) {
    312         key_buf[pos] = labels_[node];
    313       }
    314       ++pos;
    315     }
    316     node = get_parent(node);
    317   }
    318   if (pos < key_buf_size) {
    319     key_buf[pos] = '\0';
    320     std::reverse(key_buf, key_buf + pos);
    321   }
    322   return pos;
    323 }
    324 
    325 void Trie::trie_restore(UInt32 node, char *key_buf,
    326     std::size_t key_buf_size, std::size_t &pos) const {
    327   do {
    328     if (has_link(node)) {
    329       if (has_trie()) {
    330         trie_->trie_restore(get_link(node), key_buf, key_buf_size, pos);
    331       } else {
    332         tail_restore(node, key_buf, key_buf_size, pos);
    333       }
    334     } else {
    335       if (pos < key_buf_size) {
    336         key_buf[pos] = labels_[node];
    337       }
    338       ++pos;
    339     }
    340     node = get_parent(node);
    341   } while (node != 0);
    342 }
    343 
    344 void Trie::tail_restore(UInt32 node, char *key_buf,
    345     std::size_t key_buf_size, std::size_t &pos) const {
    346   const UInt32 link_id = link_flags_.rank1(node);
    347   const UInt32 offset = (links_[link_id] * 256) + labels_[node];
    348   if (tail_.mode() == MARISA_BINARY_TAIL) {
    349     const UInt8 *ptr = tail_[offset];
    350     const UInt32 length = (links_[link_id + 1] * 256)
    351         + labels_[link_flags_.select1(link_id + 1)] - offset;
    352     for (UInt32 i = 0; i < length; ++i) {
    353       if (pos < key_buf_size) {
    354         key_buf[pos] = ptr[i];
    355       }
    356       ++pos;
    357     }
    358   } else {
    359     for (const UInt8 *str = tail_[offset]; *str != '\0'; ++str) {
    360       if (pos < key_buf_size) {
    361         key_buf[pos] = *str;
    362       }
    363       ++pos;
    364     }
    365   }
    366 }
    367 
    368 template <typename T>
    369 UInt32 Trie::lookup_(T query) const {
    370   UInt32 node = 0;
    371   std::size_t pos = 0;
    372   while (!query.ends_at(pos)) {
    373     if (!find_child<T>(node, query, pos)) {
    374       return notfound();
    375     }
    376   }
    377   return terminal_flags_[node] ? node_to_key_id(node) : notfound();
    378 }
    379 
    380 template <typename T>
    381 std::size_t Trie::trie_match(UInt32 node, T query,
    382     std::size_t pos) const {
    383   if (has_link(node)) {
    384     std::size_t next_pos;
    385     if (has_trie()) {
    386       next_pos = trie_->trie_match<T>(get_link(node), query, pos);
    387     } else {
    388       next_pos = tail_match<T>(node, get_link_id(node), query, pos);
    389     }
    390     if ((next_pos == mismatch()) || (next_pos == pos)) {
    391       return next_pos;
    392     }
    393     pos = next_pos;
    394   } else if (labels_[node] != query[pos]) {
    395     return pos;
    396   } else {
    397     ++pos;
    398   }
    399   node = get_parent(node);
    400   while (node != 0) {
    401     if (query.ends_at(pos)) {
    402       return mismatch();
    403     }
    404     if (has_link(node)) {
    405       std::size_t next_pos;
    406       if (has_trie()) {
    407         next_pos = trie_->trie_match<T>(get_link(node), query, pos);
    408       } else {
    409         next_pos = tail_match<T>(node, get_link_id(node), query, pos);
    410       }
    411       if ((next_pos == mismatch()) || (next_pos == pos)) {
    412         return mismatch();
    413       }
    414       pos = next_pos;
    415     } else if (labels_[node] != query[pos]) {
    416       return mismatch();
    417     } else {
    418       ++pos;
    419     }
    420     node = get_parent(node);
    421   }
    422   return pos;
    423 }
    424 
    425 template std::size_t Trie::trie_match<CQuery>(UInt32 node,
    426     CQuery query, std::size_t pos) const;
    427 template std::size_t Trie::trie_match<const Query &>(UInt32 node,
    428     const Query &query, std::size_t pos) const;
    429 
    430 template <typename T>
    431 std::size_t Trie::tail_match(UInt32 node, UInt32 link_id,
    432     T query, std::size_t pos) const {
    433   const UInt32 offset = (links_[link_id] * 256) + labels_[node];
    434   const UInt8 *ptr = tail_[offset];
    435   if (*ptr != query[pos]) {
    436     return pos;
    437   } else if (tail_.mode() == MARISA_BINARY_TAIL) {
    438     const UInt32 length = (links_[link_id + 1] * 256)
    439         + labels_[link_flags_.select1(link_id + 1)] - offset;
    440     for (UInt32 i = 1; i < length; ++i) {
    441       if (query.ends_at(pos + i) || (ptr[i] != query[pos + i])) {
    442         return mismatch();
    443       }
    444     }
    445     return pos + length;
    446   } else {
    447     for (++ptr, ++pos; *ptr != '\0'; ++ptr, ++pos) {
    448       if (query.ends_at(pos) || (*ptr != query[pos])) {
    449         return mismatch();
    450       }
    451     }
    452     return pos;
    453   }
    454 }
    455 
    456 template std::size_t Trie::tail_match<CQuery>(UInt32 node,
    457     UInt32 link_id, CQuery query, std::size_t pos) const;
    458 template std::size_t Trie::tail_match<const Query &>(UInt32 node,
    459     UInt32 link_id, const Query &query, std::size_t pos) const;
    460 
    461 template <typename T, typename U, typename V>
    462 std::size_t Trie::find_(T query, U key_ids, V key_lengths,
    463     std::size_t max_num_results) const {
    464   if (max_num_results == 0) {
    465     return 0;
    466   }
    467   std::size_t count = 0;
    468   UInt32 node = 0;
    469   std::size_t pos = 0;
    470   do {
    471     if (terminal_flags_[node]) {
    472       if (key_ids.is_valid()) {
    473         key_ids.insert(count, node_to_key_id(node));
    474       }
    475       if (key_lengths.is_valid()) {
    476         key_lengths.insert(count, pos);
    477       }
    478       if (++count >= max_num_results) {
    479         return count;
    480       }
    481     }
    482   } while (!query.ends_at(pos) && find_child<T>(node, query, pos));
    483   return count;
    484 }
    485 
    486 template <typename T>
    487 UInt32 Trie::find_first_(T query, std::size_t *key_length) const {
    488   UInt32 node = 0;
    489   std::size_t pos = 0;
    490   do {
    491     if (terminal_flags_[node]) {
    492       if (key_length != NULL) {
    493         *key_length = pos;
    494       }
    495       return node_to_key_id(node);
    496     }
    497   } while (!query.ends_at(pos) && find_child<T>(node, query, pos));
    498   return notfound();
    499 }
    500 
    501 template <typename T>
    502 UInt32 Trie::find_last_(T query, std::size_t *key_length) const {
    503   UInt32 node = 0;
    504   UInt32 node_found = notfound();
    505   std::size_t pos = 0;
    506   std::size_t pos_found = mismatch();
    507   do {
    508     if (terminal_flags_[node]) {
    509       node_found = node;
    510       pos_found = pos;
    511     }
    512   } while (!query.ends_at(pos) && find_child<T>(node, query, pos));
    513   if (node_found != notfound()) {
    514     if (key_length != NULL) {
    515       *key_length = pos_found;
    516     }
    517     return node_to_key_id(node_found);
    518   }
    519   return notfound();
    520 }
    521 
    522 template <typename T, typename U, typename V>
    523 std::size_t Trie::predict_breadth_first_(T query, U key_ids, V keys,
    524     std::size_t max_num_results) const {
    525   if (max_num_results == 0) {
    526     return 0;
    527   }
    528   UInt32 node = 0;
    529   std::size_t pos = 0;
    530   while (!query.ends_at(pos)) {
    531     if (!predict_child<T>(node, query, pos, NULL)) {
    532       return 0;
    533     }
    534   }
    535   std::string key;
    536   std::size_t count = 0;
    537   if (terminal_flags_[node]) {
    538     const UInt32 key_id = node_to_key_id(node);
    539     if (key_ids.is_valid()) {
    540       key_ids.insert(count, key_id);
    541     }
    542     if (keys.is_valid()) {
    543       restore(key_id, &key);
    544       keys.insert(count, key);
    545     }
    546     if (++count >= max_num_results) {
    547       return count;
    548     }
    549   }
    550   const UInt32 louds_pos = get_child(node);
    551   if (!louds_[louds_pos]) {
    552     return count;
    553   }
    554   UInt32 node_begin = louds_pos_to_node(louds_pos, node);
    555   UInt32 node_end = louds_pos_to_node(get_child(node + 1), node + 1);
    556   while (node_begin < node_end) {
    557     const UInt32 key_id_begin = node_to_key_id(node_begin);
    558     const UInt32 key_id_end = node_to_key_id(node_end);
    559     if (key_ids.is_valid()) {
    560       UInt32 temp_count = count;
    561       for (UInt32 key_id = key_id_begin; key_id < key_id_end; ++key_id) {
    562         key_ids.insert(temp_count, key_id);
    563         if (++temp_count >= max_num_results) {
    564           break;
    565         }
    566       }
    567     }
    568     if (keys.is_valid()) {
    569       UInt32 temp_count = count;
    570       for (UInt32 key_id = key_id_begin; key_id < key_id_end; ++key_id) {
    571         key.clear();
    572         restore(key_id, &key);
    573         keys.insert(temp_count, key);
    574         if (++temp_count >= max_num_results) {
    575           break;
    576         }
    577       }
    578     }
    579     count += key_id_end - key_id_begin;
    580     if (count >= max_num_results) {
    581       return max_num_results;
    582     }
    583     node_begin = louds_pos_to_node(get_child(node_begin), node_begin);
    584     node_end = louds_pos_to_node(get_child(node_end), node_end);
    585   }
    586   return count;
    587 }
    588 
    589 template <typename T, typename U, typename V>
    590 std::size_t Trie::predict_depth_first_(T query, U key_ids, V keys,
    591     std::size_t max_num_results) const {
    592   if (max_num_results == 0) {
    593     return 0;
    594   } else if (keys.is_valid()) {
    595     PredictCallback<U, V> callback(key_ids, keys, max_num_results);
    596     return predict_callback_(query, callback);
    597   }
    598 
    599   UInt32 node = 0;
    600   std::size_t pos = 0;
    601   while (!query.ends_at(pos)) {
    602     if (!predict_child<T>(node, query, pos, NULL)) {
    603       return 0;
    604     }
    605   }
    606   std::size_t count = 0;
    607   if (terminal_flags_[node]) {
    608     if (key_ids.is_valid()) {
    609       key_ids.insert(count, node_to_key_id(node));
    610     }
    611     if (++count >= max_num_results) {
    612       return count;
    613     }
    614   }
    615   Cell cell;
    616   cell.set_louds_pos(get_child(node));
    617   if (!louds_[cell.louds_pos()]) {
    618     return count;
    619   }
    620   cell.set_node(louds_pos_to_node(cell.louds_pos(), node));
    621   cell.set_key_id(node_to_key_id(cell.node()));
    622   Vector<Cell> stack;
    623   stack.push_back(cell);
    624   std::size_t stack_pos = 1;
    625   while (stack_pos != 0) {
    626     Cell &cur = stack[stack_pos - 1];
    627     if (!louds_[cur.louds_pos()]) {
    628       cur.set_louds_pos(cur.louds_pos() + 1);
    629       --stack_pos;
    630       continue;
    631     }
    632     cur.set_louds_pos(cur.louds_pos() + 1);
    633     if (terminal_flags_[cur.node()]) {
    634       if (key_ids.is_valid()) {
    635         key_ids.insert(count, cur.key_id());
    636       }
    637       if (++count >= max_num_results) {
    638         return count;
    639       }
    640       cur.set_key_id(cur.key_id() + 1);
    641     }
    642     if (stack_pos == stack.size()) {
    643       cell.set_louds_pos(get_child(cur.node()));
    644       cell.set_node(louds_pos_to_node(cell.louds_pos(), cur.node()));
    645       cell.set_key_id(node_to_key_id(cell.node()));
    646       stack.push_back(cell);
    647     }
    648     stack[stack_pos - 1].set_node(stack[stack_pos - 1].node() + 1);
    649     ++stack_pos;
    650   }
    651   return count;
    652 }
    653 
    654 template <typename T>
    655 std::size_t Trie::trie_prefix_match(UInt32 node, T query,
    656     std::size_t pos, std::string *key) const {
    657   if (has_link(node)) {
    658     std::size_t next_pos;
    659     if (has_trie()) {
    660       next_pos = trie_->trie_prefix_match<T>(get_link(node), query, pos, key);
    661     } else {
    662       next_pos = tail_prefix_match<T>(
    663           node, get_link_id(node), query, pos, key);
    664     }
    665     if ((next_pos == mismatch()) || (next_pos == pos)) {
    666       return next_pos;
    667     }
    668     pos = next_pos;
    669   } else if (labels_[node] != query[pos]) {
    670     return pos;
    671   } else {
    672     ++pos;
    673   }
    674   node = get_parent(node);
    675   while (node != 0) {
    676     if (query.ends_at(pos)) {
    677       if (key != NULL) {
    678         trie_restore(node, key);
    679       }
    680       return pos;
    681     }
    682     if (has_link(node)) {
    683       std::size_t next_pos;
    684       if (has_trie()) {
    685         next_pos = trie_->trie_prefix_match<T>(
    686             get_link(node), query, pos, key);
    687       } else {
    688         next_pos = tail_prefix_match<T>(
    689             node, get_link_id(node), query, pos, key);
    690       }
    691       if ((next_pos == mismatch()) || (next_pos == pos)) {
    692         return next_pos;
    693       }
    694       pos = next_pos;
    695     } else if (labels_[node] != query[pos]) {
    696       return mismatch();
    697     } else {
    698       ++pos;
    699     }
    700     node = get_parent(node);
    701   }
    702   return pos;
    703 }
    704 
    705 template std::size_t Trie::trie_prefix_match<CQuery>(UInt32 node,
    706     CQuery query, std::size_t pos, std::string *key) const;
    707 template std::size_t Trie::trie_prefix_match<const Query &>(UInt32 node,
    708     const Query &query, std::size_t pos, std::string *key) const;
    709 
    710 template <typename T>
    711 std::size_t Trie::tail_prefix_match(UInt32 node, UInt32 link_id,
    712     T query, std::size_t pos, std::string *key) const {
    713   const UInt32 offset = (links_[link_id] * 256) + labels_[node];
    714   const UInt8 *ptr = tail_[offset];
    715   if (*ptr != query[pos]) {
    716     return pos;
    717   } else if (tail_.mode() == MARISA_BINARY_TAIL) {
    718     const UInt32 length = (links_[link_id + 1] * 256)
    719         + labels_[link_flags_.select1(link_id + 1)] - offset;
    720     for (UInt32 i = 1; i < length; ++i) {
    721       if (query.ends_at(pos + i)) {
    722         if (key != NULL) {
    723           key->append(reinterpret_cast<const char *>(ptr + i), length - i);
    724         }
    725         return pos + i;
    726       } else if (ptr[i] != query[pos + i]) {
    727         return mismatch();
    728       }
    729     }
    730     return pos + length;
    731   } else {
    732     for (++ptr, ++pos; *ptr != '\0'; ++ptr, ++pos) {
    733       if (query.ends_at(pos)) {
    734         if (key != NULL) {
    735           key->append(reinterpret_cast<const char *>(ptr));
    736         }
    737         return pos;
    738       } else if (*ptr != query[pos]) {
    739         return mismatch();
    740       }
    741     }
    742     return pos;
    743   }
    744 }
    745 
    746 template std::size_t Trie::tail_prefix_match<CQuery>(
    747     UInt32 node, UInt32 link_id,
    748     CQuery query, std::size_t pos, std::string *key) const;
    749 template std::size_t Trie::tail_prefix_match<const Query &>(
    750     UInt32 node, UInt32 link_id,
    751     const Query &query, std::size_t pos, std::string *key) const;
    752 
    753 }  // namespace marisa
    754