1 #ifndef MARISA_ALPHA_TRIE_INLINE_H_ 2 #define MARISA_ALPHA_TRIE_INLINE_H_ 3 4 #include <stdexcept> 5 6 #include "cell.h" 7 8 namespace marisa_alpha { 9 10 inline std::string Trie::operator[](UInt32 key_id) const { 11 std::string key; 12 restore(key_id, &key); 13 return key; 14 } 15 16 inline UInt32 Trie::operator[](const char *str) const { 17 return lookup(str); 18 } 19 20 inline UInt32 Trie::operator[](const std::string &str) const { 21 return lookup(str); 22 } 23 24 inline UInt32 Trie::lookup(const std::string &str) const { 25 return lookup(str.c_str(), str.length()); 26 } 27 28 inline std::size_t Trie::find(const std::string &str, 29 UInt32 *key_ids, std::size_t *key_lengths, 30 std::size_t max_num_results) const { 31 return find(str.c_str(), str.length(), 32 key_ids, key_lengths, max_num_results); 33 } 34 35 inline std::size_t Trie::find(const std::string &str, 36 std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths, 37 std::size_t max_num_results) const { 38 return find(str.c_str(), str.length(), 39 key_ids, key_lengths, max_num_results); 40 } 41 42 inline UInt32 Trie::find_first(const std::string &str, 43 std::size_t *key_length) const { 44 return find_first(str.c_str(), str.length(), key_length); 45 } 46 47 inline UInt32 Trie::find_last(const std::string &str, 48 std::size_t *key_length) const { 49 return find_last(str.c_str(), str.length(), key_length); 50 } 51 52 template <typename T> 53 inline std::size_t Trie::find_callback(const char *str, 54 T callback) const { 55 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR); 56 MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR); 57 return find_callback_<CQuery>(CQuery(str), callback); 58 } 59 60 template <typename T> 61 inline std::size_t Trie::find_callback(const char *ptr, std::size_t length, 62 T callback) const { 63 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR); 64 MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0), 65 MARISA_ALPHA_PARAM_ERROR); 66 return find_callback_<const Query &>(Query(ptr, length), callback); 67 } 68 69 template <typename T> 70 inline std::size_t Trie::find_callback(const std::string &str, 71 T callback) const { 72 return find_callback(str.c_str(), str.length(), callback); 73 } 74 75 inline std::size_t Trie::predict(const std::string &str, 76 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const { 77 return predict(str.c_str(), str.length(), key_ids, keys, max_num_results); 78 } 79 80 inline std::size_t Trie::predict(const std::string &str, 81 std::vector<UInt32> *key_ids, std::vector<std::string> *keys, 82 std::size_t max_num_results) const { 83 return predict(str.c_str(), str.length(), key_ids, keys, max_num_results); 84 } 85 86 inline std::size_t Trie::predict_breadth_first(const std::string &str, 87 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const { 88 return predict_breadth_first(str.c_str(), str.length(), 89 key_ids, keys, max_num_results); 90 } 91 92 inline std::size_t Trie::predict_breadth_first(const std::string &str, 93 std::vector<UInt32> *key_ids, std::vector<std::string> *keys, 94 std::size_t max_num_results) const { 95 return predict_breadth_first(str.c_str(), str.length(), 96 key_ids, keys, max_num_results); 97 } 98 99 inline std::size_t Trie::predict_depth_first(const std::string &str, 100 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const { 101 return predict_depth_first(str.c_str(), str.length(), 102 key_ids, keys, max_num_results); 103 } 104 105 inline std::size_t Trie::predict_depth_first(const std::string &str, 106 std::vector<UInt32> *key_ids, std::vector<std::string> *keys, 107 std::size_t max_num_results) const { 108 return predict_depth_first(str.c_str(), str.length(), 109 key_ids, keys, max_num_results); 110 } 111 112 template <typename T> 113 inline std::size_t Trie::predict_callback( 114 const char *str, T callback) const { 115 return predict_callback_<CQuery>(CQuery(str), callback); 116 } 117 118 template <typename T> 119 inline std::size_t Trie::predict_callback( 120 const char *ptr, std::size_t length, 121 T callback) const { 122 return predict_callback_<const Query &>(Query(ptr, length), callback); 123 } 124 125 template <typename T> 126 inline std::size_t Trie::predict_callback( 127 const std::string &str, T callback) const { 128 return predict_callback(str.c_str(), str.length(), callback); 129 } 130 131 inline bool Trie::empty() const { 132 return louds_.empty(); 133 } 134 135 inline std::size_t Trie::num_keys() const { 136 return num_keys_; 137 } 138 139 inline UInt32 Trie::notfound() { 140 return MARISA_ALPHA_NOT_FOUND; 141 } 142 143 inline std::size_t Trie::mismatch() { 144 return MARISA_ALPHA_MISMATCH; 145 } 146 147 template <typename T> 148 inline bool Trie::find_child(UInt32 &node, T query, 149 std::size_t &pos) const { 150 UInt32 louds_pos = get_child(node); 151 if (!louds_[louds_pos]) { 152 return false; 153 } 154 node = louds_pos_to_node(louds_pos, node); 155 UInt32 link_id = MARISA_ALPHA_UINT32_MAX; 156 do { 157 if (has_link(node)) { 158 if (link_id == MARISA_ALPHA_UINT32_MAX) { 159 link_id = get_link_id(node); 160 } else { 161 ++link_id; 162 } 163 std::size_t next_pos = has_trie() ? 164 trie_->trie_match<T>(get_link(node, link_id), query, pos) : 165 tail_match<T>(node, link_id, query, pos); 166 if (next_pos == mismatch()) { 167 return false; 168 } else if (next_pos != pos) { 169 pos = next_pos; 170 return true; 171 } 172 } else if (labels_[node] == query[pos]) { 173 ++pos; 174 return true; 175 } 176 ++node; 177 ++louds_pos; 178 } while (louds_[louds_pos]); 179 return false; 180 } 181 182 template <typename T, typename U> 183 std::size_t Trie::find_callback_(T query, U callback) const try { 184 std::size_t count = 0; 185 UInt32 node = 0; 186 std::size_t pos = 0; 187 do { 188 if (terminal_flags_[node]) { 189 ++count; 190 if (!callback(node_to_key_id(node), pos)) { 191 return count; 192 } 193 } 194 } while (!query.ends_at(pos) && find_child<T>(node, query, pos)); 195 return count; 196 } catch (const std::bad_alloc &) { 197 MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR); 198 } catch (const std::length_error &) { 199 MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR); 200 } 201 202 template <typename T> 203 inline bool Trie::predict_child(UInt32 &node, T query, std::size_t &pos, 204 std::string *key) const { 205 UInt32 louds_pos = get_child(node); 206 if (!louds_[louds_pos]) { 207 return false; 208 } 209 node = louds_pos_to_node(louds_pos, node); 210 UInt32 link_id = MARISA_ALPHA_UINT32_MAX; 211 do { 212 if (has_link(node)) { 213 if (link_id == MARISA_ALPHA_UINT32_MAX) { 214 link_id = get_link_id(node); 215 } else { 216 ++link_id; 217 } 218 std::size_t next_pos = has_trie() ? 219 trie_->trie_prefix_match<T>( 220 get_link(node, link_id), query, pos, key) : 221 tail_prefix_match<T>(node, link_id, query, pos, key); 222 if (next_pos == mismatch()) { 223 return false; 224 } else if (next_pos != pos) { 225 pos = next_pos; 226 return true; 227 } 228 } else if (labels_[node] == query[pos]) { 229 ++pos; 230 return true; 231 } 232 ++node; 233 ++louds_pos; 234 } while (louds_[louds_pos]); 235 return false; 236 } 237 238 template <typename T, typename U> 239 std::size_t Trie::predict_callback_(T query, U callback) const try { 240 std::string key; 241 UInt32 node = 0; 242 std::size_t pos = 0; 243 while (!query.ends_at(pos)) { 244 if (!predict_child<T>(node, query, pos, &key)) { 245 return 0; 246 } 247 } 248 query.insert(&key); 249 std::size_t count = 0; 250 if (terminal_flags_[node]) { 251 ++count; 252 if (!callback(node_to_key_id(node), key)) { 253 return count; 254 } 255 } 256 Cell cell; 257 cell.set_louds_pos(get_child(node)); 258 if (!louds_[cell.louds_pos()]) { 259 return count; 260 } 261 cell.set_node(louds_pos_to_node(cell.louds_pos(), node)); 262 cell.set_key_id(node_to_key_id(cell.node())); 263 cell.set_length(key.length()); 264 Vector<Cell> stack; 265 stack.push_back(cell); 266 std::size_t stack_pos = 1; 267 while (stack_pos != 0) { 268 Cell &cur = stack[stack_pos - 1]; 269 if (!louds_[cur.louds_pos()]) { 270 cur.set_louds_pos(cur.louds_pos() + 1); 271 --stack_pos; 272 continue; 273 } 274 cur.set_louds_pos(cur.louds_pos() + 1); 275 key.resize(cur.length()); 276 if (has_link(cur.node())) { 277 if (has_trie()) { 278 trie_->trie_restore(get_link(cur.node()), &key); 279 } else { 280 tail_restore(cur.node(), &key); 281 } 282 } else { 283 key += labels_[cur.node()]; 284 } 285 if (terminal_flags_[cur.node()]) { 286 ++count; 287 if (!callback(cur.key_id(), key)) { 288 return count; 289 } 290 cur.set_key_id(cur.key_id() + 1); 291 } 292 if (stack_pos == stack.size()) { 293 cell.set_louds_pos(get_child(cur.node())); 294 cell.set_node(louds_pos_to_node(cell.louds_pos(), cur.node())); 295 cell.set_key_id(node_to_key_id(cell.node())); 296 stack.push_back(cell); 297 } 298 stack[stack_pos].set_length(key.length()); 299 stack[stack_pos - 1].set_node(stack[stack_pos - 1].node() + 1); 300 ++stack_pos; 301 } 302 return count; 303 } catch (const std::bad_alloc &) { 304 MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR); 305 } catch (const std::length_error &) { 306 MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR); 307 } 308 309 inline UInt32 Trie::key_id_to_node(UInt32 key_id) const { 310 return terminal_flags_.select1(key_id); 311 } 312 313 inline UInt32 Trie::node_to_key_id(UInt32 node) const { 314 return terminal_flags_.rank1(node); 315 } 316 317 inline UInt32 Trie::louds_pos_to_node(UInt32 louds_pos, 318 UInt32 parent_node) const { 319 return louds_pos - parent_node - 1; 320 } 321 322 inline UInt32 Trie::get_child(UInt32 node) const { 323 return louds_.select0(node) + 1; 324 } 325 326 inline UInt32 Trie::get_parent(UInt32 node) const { 327 return (node > num_first_branches_) ? (louds_.select1(node) - node - 1) : 0; 328 } 329 330 inline bool Trie::has_link(UInt32 node) const { 331 return (link_flags_.empty()) ? false : link_flags_[node]; 332 } 333 334 inline UInt32 Trie::get_link_id(UInt32 node) const { 335 return link_flags_.rank1(node); 336 } 337 338 inline UInt32 Trie::get_link(UInt32 node) const { 339 return get_link(node, get_link_id(node)); 340 } 341 342 inline UInt32 Trie::get_link(UInt32 node, UInt32 link_id) const { 343 return (links_[link_id] * 256) + labels_[node]; 344 } 345 346 inline bool Trie::has_link() const { 347 return !link_flags_.empty(); 348 } 349 350 inline bool Trie::has_trie() const { 351 return trie_.get() != NULL; 352 } 353 354 inline bool Trie::has_tail() const { 355 return !tail_.empty(); 356 } 357 358 } // namespace marisa_alpha 359 360 #endif // MARISA_ALPHA_TRIE_INLINE_H_ 361