1 #include <algorithm> 2 #include <stdexcept> 3 4 #include "trie.h" 5 6 namespace marisa { 7 namespace { 8 9 template <typename T, typename U> 10 class PredictCallback { 11 public: 12 PredictCallback(T key_ids, U keys, std::size_t max_num_results) 13 : key_ids_(key_ids), keys_(keys), 14 max_num_results_(max_num_results), num_results_(0) {} 15 PredictCallback(const PredictCallback &callback) 16 : key_ids_(callback.key_ids_), keys_(callback.keys_), 17 max_num_results_(callback.max_num_results_), 18 num_results_(callback.num_results_) {} 19 20 bool operator()(marisa::UInt32 key_id, const std::string &key) { 21 if (key_ids_.is_valid()) { 22 key_ids_.insert(num_results_, key_id); 23 } 24 if (keys_.is_valid()) { 25 keys_.insert(num_results_, key); 26 } 27 return ++num_results_ < max_num_results_; 28 } 29 30 private: 31 T key_ids_; 32 U keys_; 33 const std::size_t max_num_results_; 34 std::size_t num_results_; 35 36 // Disallows assignment. 37 PredictCallback &operator=(const PredictCallback &); 38 }; 39 40 } // namespace 41 42 std::string Trie::restore(UInt32 key_id) const { 43 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR); 44 MARISA_THROW_IF(key_id >= num_keys_, MARISA_PARAM_ERROR); 45 std::string key; 46 restore_(key_id, &key); 47 return key; 48 } 49 50 void Trie::restore(UInt32 key_id, std::string *key) const { 51 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR); 52 MARISA_THROW_IF(key_id >= num_keys_, MARISA_PARAM_ERROR); 53 MARISA_THROW_IF(key == NULL, MARISA_PARAM_ERROR); 54 restore_(key_id, key); 55 } 56 57 std::size_t Trie::restore(UInt32 key_id, char *key_buf, 58 std::size_t key_buf_size) const { 59 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR); 60 MARISA_THROW_IF(key_id >= num_keys_, MARISA_PARAM_ERROR); 61 MARISA_THROW_IF((key_buf == NULL) && (key_buf_size != 0), 62 MARISA_PARAM_ERROR); 63 return restore_(key_id, key_buf, key_buf_size); 64 } 65 66 UInt32 Trie::lookup(const char *str) const { 67 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR); 68 MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR); 69 return lookup_<CQuery>(CQuery(str)); 70 } 71 72 UInt32 Trie::lookup(const char *ptr, std::size_t length) const { 73 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR); 74 MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR); 75 return lookup_<const Query &>(Query(ptr, length)); 76 } 77 78 std::size_t Trie::find(const char *str, 79 UInt32 *key_ids, std::size_t *key_lengths, 80 std::size_t max_num_results) const { 81 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR); 82 MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR); 83 return find_<CQuery>(CQuery(str), 84 MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results); 85 } 86 87 std::size_t Trie::find(const char *ptr, std::size_t length, 88 UInt32 *key_ids, std::size_t *key_lengths, 89 std::size_t max_num_results) const { 90 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR); 91 MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR); 92 return find_<const Query &>(Query(ptr, length), 93 MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results); 94 } 95 96 std::size_t Trie::find(const char *str, 97 std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths, 98 std::size_t max_num_results) const { 99 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR); 100 MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR); 101 return find_<CQuery>(CQuery(str), 102 MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results); 103 } 104 105 std::size_t Trie::find(const char *ptr, std::size_t length, 106 std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths, 107 std::size_t max_num_results) const { 108 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR); 109 MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR); 110 return find_<const Query &>(Query(ptr, length), 111 MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results); 112 } 113 114 UInt32 Trie::find_first(const char *str, 115 std::size_t *key_length) const { 116 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR); 117 MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR); 118 return find_first_<CQuery>(CQuery(str), key_length); 119 } 120 121 UInt32 Trie::find_first(const char *ptr, std::size_t length, 122 std::size_t *key_length) const { 123 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR); 124 MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR); 125 return find_first_<const Query &>(Query(ptr, length), key_length); 126 } 127 128 UInt32 Trie::find_last(const char *str, 129 std::size_t *key_length) const { 130 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR); 131 MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR); 132 return find_last_<CQuery>(CQuery(str), key_length); 133 } 134 135 UInt32 Trie::find_last(const char *ptr, std::size_t length, 136 std::size_t *key_length) const { 137 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR); 138 MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR); 139 return find_last_<const Query &>(Query(ptr, length), key_length); 140 } 141 142 std::size_t Trie::predict(const char *str, 143 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const { 144 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR); 145 MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR); 146 return (keys == NULL) ? 147 predict_breadth_first(str, key_ids, keys, max_num_results) : 148 predict_depth_first(str, key_ids, keys, max_num_results); 149 } 150 151 std::size_t Trie::predict(const char *ptr, std::size_t length, 152 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const { 153 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR); 154 MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR); 155 return (keys == NULL) ? 156 predict_breadth_first(ptr, length, key_ids, keys, max_num_results) : 157 predict_depth_first(ptr, length, key_ids, keys, max_num_results); 158 } 159 160 std::size_t Trie::predict(const char *str, 161 std::vector<UInt32> *key_ids, std::vector<std::string> *keys, 162 std::size_t max_num_results) const { 163 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR); 164 MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR); 165 return (keys == NULL) ? 166 predict_breadth_first(str, key_ids, keys, max_num_results) : 167 predict_depth_first(str, key_ids, keys, max_num_results); 168 } 169 170 std::size_t Trie::predict(const char *ptr, std::size_t length, 171 std::vector<UInt32> *key_ids, std::vector<std::string> *keys, 172 std::size_t max_num_results) const { 173 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR); 174 MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR); 175 return (keys == NULL) ? 176 predict_breadth_first(ptr, length, key_ids, keys, max_num_results) : 177 predict_depth_first(ptr, length, key_ids, keys, max_num_results); 178 } 179 180 std::size_t Trie::predict_breadth_first(const char *str, 181 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const { 182 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR); 183 MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR); 184 return predict_breadth_first_<CQuery>(CQuery(str), 185 MakeContainer(key_ids), MakeContainer(keys), max_num_results); 186 } 187 188 std::size_t Trie::predict_breadth_first(const char *ptr, std::size_t length, 189 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const { 190 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR); 191 MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR); 192 return predict_breadth_first_<const Query &>(Query(ptr, length), 193 MakeContainer(key_ids), MakeContainer(keys), max_num_results); 194 } 195 196 std::size_t Trie::predict_breadth_first(const char *str, 197 std::vector<UInt32> *key_ids, std::vector<std::string> *keys, 198 std::size_t max_num_results) const { 199 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR); 200 MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR); 201 return predict_breadth_first_<CQuery>(CQuery(str), 202 MakeContainer(key_ids), MakeContainer(keys), max_num_results); 203 } 204 205 std::size_t Trie::predict_breadth_first(const char *ptr, std::size_t length, 206 std::vector<UInt32> *key_ids, std::vector<std::string> *keys, 207 std::size_t max_num_results) const { 208 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR); 209 MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR); 210 return predict_breadth_first_<const Query &>(Query(ptr, length), 211 MakeContainer(key_ids), MakeContainer(keys), max_num_results); 212 } 213 214 std::size_t Trie::predict_depth_first(const char *str, 215 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const { 216 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR); 217 MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR); 218 return predict_depth_first_<CQuery>(CQuery(str), 219 MakeContainer(key_ids), MakeContainer(keys), max_num_results); 220 } 221 222 std::size_t Trie::predict_depth_first(const char *ptr, std::size_t length, 223 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const { 224 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR); 225 MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR); 226 return predict_depth_first_<const Query &>(Query(ptr, length), 227 MakeContainer(key_ids), MakeContainer(keys), max_num_results); 228 } 229 230 std::size_t Trie::predict_depth_first( 231 const char *str, std::vector<UInt32> *key_ids, 232 std::vector<std::string> *keys, std::size_t max_num_results) const { 233 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR); 234 MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR); 235 return predict_depth_first_<CQuery>(CQuery(str), 236 MakeContainer(key_ids), MakeContainer(keys), max_num_results); 237 } 238 239 std::size_t Trie::predict_depth_first( 240 const char *ptr, std::size_t length, std::vector<UInt32> *key_ids, 241 std::vector<std::string> *keys, std::size_t max_num_results) const { 242 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR); 243 MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR); 244 return predict_depth_first_<const Query &>(Query(ptr, length), 245 MakeContainer(key_ids), MakeContainer(keys), max_num_results); 246 } 247 248 void Trie::restore_(UInt32 key_id, std::string *key) const { 249 const std::size_t start_pos = key->length(); 250 UInt32 node = key_id_to_node(key_id); 251 while (node != 0) { 252 if (has_link(node)) { 253 const std::size_t prev_pos = key->length(); 254 if (has_trie()) { 255 trie_->trie_restore(get_link(node), key); 256 } else { 257 tail_restore(node, key); 258 } 259 std::reverse(key->begin() + prev_pos, key->end()); 260 } else { 261 *key += labels_[node]; 262 } 263 node = get_parent(node); 264 } 265 std::reverse(key->begin() + start_pos, key->end()); 266 } 267 268 void Trie::trie_restore(UInt32 node, std::string *key) const { 269 do { 270 if (has_link(node)) { 271 if (has_trie()) { 272 trie_->trie_restore(get_link(node), key); 273 } else { 274 tail_restore(node, key); 275 } 276 } else { 277 *key += labels_[node]; 278 } 279 node = get_parent(node); 280 } while (node != 0); 281 } 282 283 void Trie::tail_restore(UInt32 node, std::string *key) const { 284 const UInt32 link_id = link_flags_.rank1(node); 285 const UInt32 offset = (links_[link_id] * 256) + labels_[node]; 286 if (tail_.mode() == MARISA_BINARY_TAIL) { 287 const UInt32 length = (links_[link_id + 1] * 256) 288 + labels_[link_flags_.select1(link_id + 1)] - offset; 289 key->append(reinterpret_cast<const char *>(tail_[offset]), length); 290 } else { 291 key->append(reinterpret_cast<const char *>(tail_[offset])); 292 } 293 } 294 295 std::size_t Trie::restore_(UInt32 key_id, char *key_buf, 296 std::size_t key_buf_size) const { 297 std::size_t pos = 0; 298 UInt32 node = key_id_to_node(key_id); 299 while (node != 0) { 300 if (has_link(node)) { 301 const std::size_t prev_pos = pos; 302 if (has_trie()) { 303 trie_->trie_restore(get_link(node), key_buf, key_buf_size, pos); 304 } else { 305 tail_restore(node, key_buf, key_buf_size, pos); 306 } 307 if (pos < key_buf_size) { 308 std::reverse(key_buf + prev_pos, key_buf + pos); 309 } 310 } else { 311 if (pos < key_buf_size) { 312 key_buf[pos] = labels_[node]; 313 } 314 ++pos; 315 } 316 node = get_parent(node); 317 } 318 if (pos < key_buf_size) { 319 key_buf[pos] = '\0'; 320 std::reverse(key_buf, key_buf + pos); 321 } 322 return pos; 323 } 324 325 void Trie::trie_restore(UInt32 node, char *key_buf, 326 std::size_t key_buf_size, std::size_t &pos) const { 327 do { 328 if (has_link(node)) { 329 if (has_trie()) { 330 trie_->trie_restore(get_link(node), key_buf, key_buf_size, pos); 331 } else { 332 tail_restore(node, key_buf, key_buf_size, pos); 333 } 334 } else { 335 if (pos < key_buf_size) { 336 key_buf[pos] = labels_[node]; 337 } 338 ++pos; 339 } 340 node = get_parent(node); 341 } while (node != 0); 342 } 343 344 void Trie::tail_restore(UInt32 node, char *key_buf, 345 std::size_t key_buf_size, std::size_t &pos) const { 346 const UInt32 link_id = link_flags_.rank1(node); 347 const UInt32 offset = (links_[link_id] * 256) + labels_[node]; 348 if (tail_.mode() == MARISA_BINARY_TAIL) { 349 const UInt8 *ptr = tail_[offset]; 350 const UInt32 length = (links_[link_id + 1] * 256) 351 + labels_[link_flags_.select1(link_id + 1)] - offset; 352 for (UInt32 i = 0; i < length; ++i) { 353 if (pos < key_buf_size) { 354 key_buf[pos] = ptr[i]; 355 } 356 ++pos; 357 } 358 } else { 359 for (const UInt8 *str = tail_[offset]; *str != '\0'; ++str) { 360 if (pos < key_buf_size) { 361 key_buf[pos] = *str; 362 } 363 ++pos; 364 } 365 } 366 } 367 368 template <typename T> 369 UInt32 Trie::lookup_(T query) const { 370 UInt32 node = 0; 371 std::size_t pos = 0; 372 while (!query.ends_at(pos)) { 373 if (!find_child<T>(node, query, pos)) { 374 return notfound(); 375 } 376 } 377 return terminal_flags_[node] ? node_to_key_id(node) : notfound(); 378 } 379 380 template <typename T> 381 std::size_t Trie::trie_match(UInt32 node, T query, 382 std::size_t pos) const { 383 if (has_link(node)) { 384 std::size_t next_pos; 385 if (has_trie()) { 386 next_pos = trie_->trie_match<T>(get_link(node), query, pos); 387 } else { 388 next_pos = tail_match<T>(node, get_link_id(node), query, pos); 389 } 390 if ((next_pos == mismatch()) || (next_pos == pos)) { 391 return next_pos; 392 } 393 pos = next_pos; 394 } else if (labels_[node] != query[pos]) { 395 return pos; 396 } else { 397 ++pos; 398 } 399 node = get_parent(node); 400 while (node != 0) { 401 if (query.ends_at(pos)) { 402 return mismatch(); 403 } 404 if (has_link(node)) { 405 std::size_t next_pos; 406 if (has_trie()) { 407 next_pos = trie_->trie_match<T>(get_link(node), query, pos); 408 } else { 409 next_pos = tail_match<T>(node, get_link_id(node), query, pos); 410 } 411 if ((next_pos == mismatch()) || (next_pos == pos)) { 412 return mismatch(); 413 } 414 pos = next_pos; 415 } else if (labels_[node] != query[pos]) { 416 return mismatch(); 417 } else { 418 ++pos; 419 } 420 node = get_parent(node); 421 } 422 return pos; 423 } 424 425 template std::size_t Trie::trie_match<CQuery>(UInt32 node, 426 CQuery query, std::size_t pos) const; 427 template std::size_t Trie::trie_match<const Query &>(UInt32 node, 428 const Query &query, std::size_t pos) const; 429 430 template <typename T> 431 std::size_t Trie::tail_match(UInt32 node, UInt32 link_id, 432 T query, std::size_t pos) const { 433 const UInt32 offset = (links_[link_id] * 256) + labels_[node]; 434 const UInt8 *ptr = tail_[offset]; 435 if (*ptr != query[pos]) { 436 return pos; 437 } else if (tail_.mode() == MARISA_BINARY_TAIL) { 438 const UInt32 length = (links_[link_id + 1] * 256) 439 + labels_[link_flags_.select1(link_id + 1)] - offset; 440 for (UInt32 i = 1; i < length; ++i) { 441 if (query.ends_at(pos + i) || (ptr[i] != query[pos + i])) { 442 return mismatch(); 443 } 444 } 445 return pos + length; 446 } else { 447 for (++ptr, ++pos; *ptr != '\0'; ++ptr, ++pos) { 448 if (query.ends_at(pos) || (*ptr != query[pos])) { 449 return mismatch(); 450 } 451 } 452 return pos; 453 } 454 } 455 456 template std::size_t Trie::tail_match<CQuery>(UInt32 node, 457 UInt32 link_id, CQuery query, std::size_t pos) const; 458 template std::size_t Trie::tail_match<const Query &>(UInt32 node, 459 UInt32 link_id, const Query &query, std::size_t pos) const; 460 461 template <typename T, typename U, typename V> 462 std::size_t Trie::find_(T query, U key_ids, V key_lengths, 463 std::size_t max_num_results) const { 464 if (max_num_results == 0) { 465 return 0; 466 } 467 std::size_t count = 0; 468 UInt32 node = 0; 469 std::size_t pos = 0; 470 do { 471 if (terminal_flags_[node]) { 472 if (key_ids.is_valid()) { 473 key_ids.insert(count, node_to_key_id(node)); 474 } 475 if (key_lengths.is_valid()) { 476 key_lengths.insert(count, pos); 477 } 478 if (++count >= max_num_results) { 479 return count; 480 } 481 } 482 } while (!query.ends_at(pos) && find_child<T>(node, query, pos)); 483 return count; 484 } 485 486 template <typename T> 487 UInt32 Trie::find_first_(T query, std::size_t *key_length) const { 488 UInt32 node = 0; 489 std::size_t pos = 0; 490 do { 491 if (terminal_flags_[node]) { 492 if (key_length != NULL) { 493 *key_length = pos; 494 } 495 return node_to_key_id(node); 496 } 497 } while (!query.ends_at(pos) && find_child<T>(node, query, pos)); 498 return notfound(); 499 } 500 501 template <typename T> 502 UInt32 Trie::find_last_(T query, std::size_t *key_length) const { 503 UInt32 node = 0; 504 UInt32 node_found = notfound(); 505 std::size_t pos = 0; 506 std::size_t pos_found = mismatch(); 507 do { 508 if (terminal_flags_[node]) { 509 node_found = node; 510 pos_found = pos; 511 } 512 } while (!query.ends_at(pos) && find_child<T>(node, query, pos)); 513 if (node_found != notfound()) { 514 if (key_length != NULL) { 515 *key_length = pos_found; 516 } 517 return node_to_key_id(node_found); 518 } 519 return notfound(); 520 } 521 522 template <typename T, typename U, typename V> 523 std::size_t Trie::predict_breadth_first_(T query, U key_ids, V keys, 524 std::size_t max_num_results) const { 525 if (max_num_results == 0) { 526 return 0; 527 } 528 UInt32 node = 0; 529 std::size_t pos = 0; 530 while (!query.ends_at(pos)) { 531 if (!predict_child<T>(node, query, pos, NULL)) { 532 return 0; 533 } 534 } 535 std::string key; 536 std::size_t count = 0; 537 if (terminal_flags_[node]) { 538 const UInt32 key_id = node_to_key_id(node); 539 if (key_ids.is_valid()) { 540 key_ids.insert(count, key_id); 541 } 542 if (keys.is_valid()) { 543 restore(key_id, &key); 544 keys.insert(count, key); 545 } 546 if (++count >= max_num_results) { 547 return count; 548 } 549 } 550 const UInt32 louds_pos = get_child(node); 551 if (!louds_[louds_pos]) { 552 return count; 553 } 554 UInt32 node_begin = louds_pos_to_node(louds_pos, node); 555 UInt32 node_end = louds_pos_to_node(get_child(node + 1), node + 1); 556 while (node_begin < node_end) { 557 const UInt32 key_id_begin = node_to_key_id(node_begin); 558 const UInt32 key_id_end = node_to_key_id(node_end); 559 if (key_ids.is_valid()) { 560 UInt32 temp_count = count; 561 for (UInt32 key_id = key_id_begin; key_id < key_id_end; ++key_id) { 562 key_ids.insert(temp_count, key_id); 563 if (++temp_count >= max_num_results) { 564 break; 565 } 566 } 567 } 568 if (keys.is_valid()) { 569 UInt32 temp_count = count; 570 for (UInt32 key_id = key_id_begin; key_id < key_id_end; ++key_id) { 571 key.clear(); 572 restore(key_id, &key); 573 keys.insert(temp_count, key); 574 if (++temp_count >= max_num_results) { 575 break; 576 } 577 } 578 } 579 count += key_id_end - key_id_begin; 580 if (count >= max_num_results) { 581 return max_num_results; 582 } 583 node_begin = louds_pos_to_node(get_child(node_begin), node_begin); 584 node_end = louds_pos_to_node(get_child(node_end), node_end); 585 } 586 return count; 587 } 588 589 template <typename T, typename U, typename V> 590 std::size_t Trie::predict_depth_first_(T query, U key_ids, V keys, 591 std::size_t max_num_results) const { 592 if (max_num_results == 0) { 593 return 0; 594 } else if (keys.is_valid()) { 595 PredictCallback<U, V> callback(key_ids, keys, max_num_results); 596 return predict_callback_(query, callback); 597 } 598 599 UInt32 node = 0; 600 std::size_t pos = 0; 601 while (!query.ends_at(pos)) { 602 if (!predict_child<T>(node, query, pos, NULL)) { 603 return 0; 604 } 605 } 606 std::size_t count = 0; 607 if (terminal_flags_[node]) { 608 if (key_ids.is_valid()) { 609 key_ids.insert(count, node_to_key_id(node)); 610 } 611 if (++count >= max_num_results) { 612 return count; 613 } 614 } 615 Cell cell; 616 cell.set_louds_pos(get_child(node)); 617 if (!louds_[cell.louds_pos()]) { 618 return count; 619 } 620 cell.set_node(louds_pos_to_node(cell.louds_pos(), node)); 621 cell.set_key_id(node_to_key_id(cell.node())); 622 Vector<Cell> stack; 623 stack.push_back(cell); 624 std::size_t stack_pos = 1; 625 while (stack_pos != 0) { 626 Cell &cur = stack[stack_pos - 1]; 627 if (!louds_[cur.louds_pos()]) { 628 cur.set_louds_pos(cur.louds_pos() + 1); 629 --stack_pos; 630 continue; 631 } 632 cur.set_louds_pos(cur.louds_pos() + 1); 633 if (terminal_flags_[cur.node()]) { 634 if (key_ids.is_valid()) { 635 key_ids.insert(count, cur.key_id()); 636 } 637 if (++count >= max_num_results) { 638 return count; 639 } 640 cur.set_key_id(cur.key_id() + 1); 641 } 642 if (stack_pos == stack.size()) { 643 cell.set_louds_pos(get_child(cur.node())); 644 cell.set_node(louds_pos_to_node(cell.louds_pos(), cur.node())); 645 cell.set_key_id(node_to_key_id(cell.node())); 646 stack.push_back(cell); 647 } 648 stack[stack_pos - 1].set_node(stack[stack_pos - 1].node() + 1); 649 ++stack_pos; 650 } 651 return count; 652 } 653 654 template <typename T> 655 std::size_t Trie::trie_prefix_match(UInt32 node, T query, 656 std::size_t pos, std::string *key) const { 657 if (has_link(node)) { 658 std::size_t next_pos; 659 if (has_trie()) { 660 next_pos = trie_->trie_prefix_match<T>(get_link(node), query, pos, key); 661 } else { 662 next_pos = tail_prefix_match<T>( 663 node, get_link_id(node), query, pos, key); 664 } 665 if ((next_pos == mismatch()) || (next_pos == pos)) { 666 return next_pos; 667 } 668 pos = next_pos; 669 } else if (labels_[node] != query[pos]) { 670 return pos; 671 } else { 672 ++pos; 673 } 674 node = get_parent(node); 675 while (node != 0) { 676 if (query.ends_at(pos)) { 677 if (key != NULL) { 678 trie_restore(node, key); 679 } 680 return pos; 681 } 682 if (has_link(node)) { 683 std::size_t next_pos; 684 if (has_trie()) { 685 next_pos = trie_->trie_prefix_match<T>( 686 get_link(node), query, pos, key); 687 } else { 688 next_pos = tail_prefix_match<T>( 689 node, get_link_id(node), query, pos, key); 690 } 691 if ((next_pos == mismatch()) || (next_pos == pos)) { 692 return next_pos; 693 } 694 pos = next_pos; 695 } else if (labels_[node] != query[pos]) { 696 return mismatch(); 697 } else { 698 ++pos; 699 } 700 node = get_parent(node); 701 } 702 return pos; 703 } 704 705 template std::size_t Trie::trie_prefix_match<CQuery>(UInt32 node, 706 CQuery query, std::size_t pos, std::string *key) const; 707 template std::size_t Trie::trie_prefix_match<const Query &>(UInt32 node, 708 const Query &query, std::size_t pos, std::string *key) const; 709 710 template <typename T> 711 std::size_t Trie::tail_prefix_match(UInt32 node, UInt32 link_id, 712 T query, std::size_t pos, std::string *key) const { 713 const UInt32 offset = (links_[link_id] * 256) + labels_[node]; 714 const UInt8 *ptr = tail_[offset]; 715 if (*ptr != query[pos]) { 716 return pos; 717 } else if (tail_.mode() == MARISA_BINARY_TAIL) { 718 const UInt32 length = (links_[link_id + 1] * 256) 719 + labels_[link_flags_.select1(link_id + 1)] - offset; 720 for (UInt32 i = 1; i < length; ++i) { 721 if (query.ends_at(pos + i)) { 722 if (key != NULL) { 723 key->append(reinterpret_cast<const char *>(ptr + i), length - i); 724 } 725 return pos + i; 726 } else if (ptr[i] != query[pos + i]) { 727 return mismatch(); 728 } 729 } 730 return pos + length; 731 } else { 732 for (++ptr, ++pos; *ptr != '\0'; ++ptr, ++pos) { 733 if (query.ends_at(pos)) { 734 if (key != NULL) { 735 key->append(reinterpret_cast<const char *>(ptr)); 736 } 737 return pos; 738 } else if (*ptr != query[pos]) { 739 return mismatch(); 740 } 741 } 742 return pos; 743 } 744 } 745 746 template std::size_t Trie::tail_prefix_match<CQuery>( 747 UInt32 node, UInt32 link_id, 748 CQuery query, std::size_t pos, std::string *key) const; 749 template std::size_t Trie::tail_prefix_match<const Query &>( 750 UInt32 node, UInt32 link_id, 751 const Query &query, std::size_t pos, std::string *key) const; 752 753 } // namespace marisa 754