1 #ifndef MARISA_TRIE_INLINE_H_ 2 #define MARISA_TRIE_INLINE_H_ 3 4 #include <stdexcept> 5 6 #include "cell.h" 7 8 namespace marisa { 9 10 inline std::string Trie::operator[](UInt32 key_id) const { 11 std::string key; 12 restore(key_id, &key); 13 return key; 14 } 15 16 inline UInt32 Trie::operator[](const char *str) const { 17 return lookup(str); 18 } 19 20 inline UInt32 Trie::operator[](const std::string &str) const { 21 return lookup(str); 22 } 23 24 inline UInt32 Trie::lookup(const std::string &str) const { 25 return lookup(str.c_str(), str.length()); 26 } 27 28 inline std::size_t Trie::find(const std::string &str, 29 UInt32 *key_ids, std::size_t *key_lengths, 30 std::size_t max_num_results) const { 31 return find(str.c_str(), str.length(), 32 key_ids, key_lengths, max_num_results); 33 } 34 35 inline std::size_t Trie::find(const std::string &str, 36 std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths, 37 std::size_t max_num_results) const { 38 return find(str.c_str(), str.length(), 39 key_ids, key_lengths, max_num_results); 40 } 41 42 inline UInt32 Trie::find_first(const std::string &str, 43 std::size_t *key_length) const { 44 return find_first(str.c_str(), str.length(), key_length); 45 } 46 47 inline UInt32 Trie::find_last(const std::string &str, 48 std::size_t *key_length) const { 49 return find_last(str.c_str(), str.length(), key_length); 50 } 51 52 template <typename T> 53 inline std::size_t Trie::find_callback(const char *str, 54 T callback) const { 55 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR); 56 MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR); 57 return find_callback_<CQuery>(CQuery(str), callback); 58 } 59 60 template <typename T> 61 inline std::size_t Trie::find_callback(const char *ptr, std::size_t length, 62 T callback) const { 63 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR); 64 MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR); 65 return find_callback_<const Query &>(Query(ptr, length), callback); 66 } 67 68 template <typename T> 69 inline std::size_t Trie::find_callback(const std::string &str, 70 T callback) const { 71 return find_callback(str.c_str(), str.length(), callback); 72 } 73 74 inline std::size_t Trie::predict(const std::string &str, 75 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const { 76 return predict(str.c_str(), str.length(), key_ids, keys, max_num_results); 77 } 78 79 inline std::size_t Trie::predict(const std::string &str, 80 std::vector<UInt32> *key_ids, std::vector<std::string> *keys, 81 std::size_t max_num_results) const { 82 return predict(str.c_str(), str.length(), key_ids, keys, max_num_results); 83 } 84 85 inline std::size_t Trie::predict_breadth_first(const std::string &str, 86 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const { 87 return predict_breadth_first(str.c_str(), str.length(), 88 key_ids, keys, max_num_results); 89 } 90 91 inline std::size_t Trie::predict_breadth_first(const std::string &str, 92 std::vector<UInt32> *key_ids, std::vector<std::string> *keys, 93 std::size_t max_num_results) const { 94 return predict_breadth_first(str.c_str(), str.length(), 95 key_ids, keys, max_num_results); 96 } 97 98 inline std::size_t Trie::predict_depth_first(const std::string &str, 99 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const { 100 return predict_depth_first(str.c_str(), str.length(), 101 key_ids, keys, max_num_results); 102 } 103 104 inline std::size_t Trie::predict_depth_first(const std::string &str, 105 std::vector<UInt32> *key_ids, std::vector<std::string> *keys, 106 std::size_t max_num_results) const { 107 return predict_depth_first(str.c_str(), str.length(), 108 key_ids, keys, max_num_results); 109 } 110 111 template <typename T> 112 inline std::size_t Trie::predict_callback( 113 const char *str, T callback) const { 114 return predict_callback_<CQuery>(CQuery(str), callback); 115 } 116 117 template <typename T> 118 inline std::size_t Trie::predict_callback( 119 const char *ptr, std::size_t length, 120 T callback) const { 121 return predict_callback_<const Query &>(Query(ptr, length), callback); 122 } 123 124 template <typename T> 125 inline std::size_t Trie::predict_callback( 126 const std::string &str, T callback) const { 127 return predict_callback(str.c_str(), str.length(), callback); 128 } 129 130 inline bool Trie::empty() const { 131 return louds_.empty(); 132 } 133 134 inline std::size_t Trie::num_keys() const { 135 return num_keys_; 136 } 137 138 inline UInt32 Trie::notfound() { 139 return MARISA_NOT_FOUND; 140 } 141 142 inline std::size_t Trie::mismatch() { 143 return MARISA_MISMATCH; 144 } 145 146 template <typename T> 147 inline bool Trie::find_child(UInt32 &node, T query, 148 std::size_t &pos) const { 149 UInt32 louds_pos = get_child(node); 150 if (!louds_[louds_pos]) { 151 return false; 152 } 153 node = louds_pos_to_node(louds_pos, node); 154 UInt32 link_id = MARISA_UINT32_MAX; 155 do { 156 if (has_link(node)) { 157 if (link_id == MARISA_UINT32_MAX) { 158 link_id = get_link_id(node); 159 } else { 160 ++link_id; 161 } 162 std::size_t next_pos = has_trie() ? 163 trie_->trie_match<T>(get_link(node, link_id), query, pos) : 164 tail_match<T>(node, link_id, query, pos); 165 if (next_pos == mismatch()) { 166 return false; 167 } else if (next_pos != pos) { 168 pos = next_pos; 169 return true; 170 } 171 } else if (labels_[node] == query[pos]) { 172 ++pos; 173 return true; 174 } 175 ++node; 176 ++louds_pos; 177 } while (louds_[louds_pos]); 178 return false; 179 } 180 181 template <typename T, typename U> 182 std::size_t Trie::find_callback_(T query, U callback) const { 183 std::size_t count = 0; 184 UInt32 node = 0; 185 std::size_t pos = 0; 186 do { 187 if (terminal_flags_[node]) { 188 ++count; 189 if (!callback(node_to_key_id(node), pos)) { 190 return count; 191 } 192 } 193 } while (!query.ends_at(pos) && find_child<T>(node, query, pos)); 194 return count; 195 } 196 197 template <typename T> 198 inline bool Trie::predict_child(UInt32 &node, T query, std::size_t &pos, 199 std::string *key) const { 200 UInt32 louds_pos = get_child(node); 201 if (!louds_[louds_pos]) { 202 return false; 203 } 204 node = louds_pos_to_node(louds_pos, node); 205 UInt32 link_id = MARISA_UINT32_MAX; 206 do { 207 if (has_link(node)) { 208 if (link_id == MARISA_UINT32_MAX) { 209 link_id = get_link_id(node); 210 } else { 211 ++link_id; 212 } 213 std::size_t next_pos = has_trie() ? 214 trie_->trie_prefix_match<T>( 215 get_link(node, link_id), query, pos, key) : 216 tail_prefix_match<T>(node, link_id, query, pos, key); 217 if (next_pos == mismatch()) { 218 return false; 219 } else if (next_pos != pos) { 220 pos = next_pos; 221 return true; 222 } 223 } else if (labels_[node] == query[pos]) { 224 ++pos; 225 return true; 226 } 227 ++node; 228 ++louds_pos; 229 } while (louds_[louds_pos]); 230 return false; 231 } 232 233 template <typename T, typename U> 234 std::size_t Trie::predict_callback_(T query, U callback) const { 235 std::string key; 236 UInt32 node = 0; 237 std::size_t pos = 0; 238 while (!query.ends_at(pos)) { 239 if (!predict_child<T>(node, query, pos, &key)) { 240 return 0; 241 } 242 } 243 query.insert(&key); 244 std::size_t count = 0; 245 if (terminal_flags_[node]) { 246 ++count; 247 if (!callback(node_to_key_id(node), key)) { 248 return count; 249 } 250 } 251 Cell cell; 252 cell.set_louds_pos(get_child(node)); 253 if (!louds_[cell.louds_pos()]) { 254 return count; 255 } 256 cell.set_node(louds_pos_to_node(cell.louds_pos(), node)); 257 cell.set_key_id(node_to_key_id(cell.node())); 258 cell.set_length(key.length()); 259 Vector<Cell> stack; 260 stack.push_back(cell); 261 std::size_t stack_pos = 1; 262 while (stack_pos != 0) { 263 Cell &cur = stack[stack_pos - 1]; 264 if (!louds_[cur.louds_pos()]) { 265 cur.set_louds_pos(cur.louds_pos() + 1); 266 --stack_pos; 267 continue; 268 } 269 cur.set_louds_pos(cur.louds_pos() + 1); 270 key.resize(cur.length()); 271 if (has_link(cur.node())) { 272 if (has_trie()) { 273 trie_->trie_restore(get_link(cur.node()), &key); 274 } else { 275 tail_restore(cur.node(), &key); 276 } 277 } else { 278 key += labels_[cur.node()]; 279 } 280 if (terminal_flags_[cur.node()]) { 281 ++count; 282 if (!callback(cur.key_id(), key)) { 283 return count; 284 } 285 cur.set_key_id(cur.key_id() + 1); 286 } 287 if (stack_pos == stack.size()) { 288 cell.set_louds_pos(get_child(cur.node())); 289 cell.set_node(louds_pos_to_node(cell.louds_pos(), cur.node())); 290 cell.set_key_id(node_to_key_id(cell.node())); 291 stack.push_back(cell); 292 } 293 stack[stack_pos].set_length(key.length()); 294 stack[stack_pos - 1].set_node(stack[stack_pos - 1].node() + 1); 295 ++stack_pos; 296 } 297 return count; 298 } 299 300 inline UInt32 Trie::key_id_to_node(UInt32 key_id) const { 301 return terminal_flags_.select1(key_id); 302 } 303 304 inline UInt32 Trie::node_to_key_id(UInt32 node) const { 305 return terminal_flags_.rank1(node); 306 } 307 308 inline UInt32 Trie::louds_pos_to_node(UInt32 louds_pos, 309 UInt32 parent_node) const { 310 return louds_pos - parent_node - 1; 311 } 312 313 inline UInt32 Trie::get_child(UInt32 node) const { 314 return louds_.select0(node) + 1; 315 } 316 317 inline UInt32 Trie::get_parent(UInt32 node) const { 318 return (node > num_first_branches_) ? (louds_.select1(node) - node - 1) : 0; 319 } 320 321 inline bool Trie::has_link(UInt32 node) const { 322 return (link_flags_.empty()) ? false : link_flags_[node]; 323 } 324 325 inline UInt32 Trie::get_link_id(UInt32 node) const { 326 return link_flags_.rank1(node); 327 } 328 329 inline UInt32 Trie::get_link(UInt32 node) const { 330 return get_link(node, get_link_id(node)); 331 } 332 333 inline UInt32 Trie::get_link(UInt32 node, UInt32 link_id) const { 334 return (links_[link_id] * 256) + labels_[node]; 335 } 336 337 inline bool Trie::has_link() const { 338 return !link_flags_.empty(); 339 } 340 341 inline bool Trie::has_trie() const { 342 return trie_.get() != NULL; 343 } 344 345 inline bool Trie::has_tail() const { 346 return !tail_.empty(); 347 } 348 349 } // namespace marisa 350 351 #endif // MARISA_TRIE_INLINE_H_ 352