1 #ifndef MARISA_TRIE_H_ 2 #define MARISA_TRIE_H_ 3 4 #include "base.h" 5 6 #ifdef __cplusplus 7 8 #include <memory> 9 #include <vector> 10 11 #include "progress.h" 12 #include "key.h" 13 #include "query.h" 14 #include "container.h" 15 #include "intvector.h" 16 #include "bitvector.h" 17 #include "tail.h" 18 19 namespace marisa { 20 21 class Trie { 22 public: 23 Trie(); 24 25 void build(const char * const *keys, std::size_t num_keys, 26 const std::size_t *key_lengths = NULL, 27 const double *key_weights = NULL, 28 UInt32 *key_ids = NULL, int flags = 0); 29 30 void build(const std::vector<std::string> &keys, 31 std::vector<UInt32> *key_ids = NULL, int flags = 0); 32 void build(const std::vector<std::pair<std::string, double> > &keys, 33 std::vector<UInt32> *key_ids = NULL, int flags = 0); 34 35 void mmap(Mapper *mapper, const char *filename, 36 long offset = 0, int whence = SEEK_SET); 37 void map(const void *ptr, std::size_t size); 38 void map(Mapper &mapper); 39 40 void load(const char *filename, 41 long offset = 0, int whence = SEEK_SET); 42 void fread(std::FILE *file); 43 void read(int fd); 44 void read(std::istream &stream); 45 void read(Reader &reader); 46 47 void save(const char *filename, bool trunc_flag = true, 48 long offset = 0, int whence = SEEK_SET) const; 49 void fwrite(std::FILE *file) const; 50 void write(int fd) const; 51 void write(std::ostream &stream) const; 52 void write(Writer &writer) const; 53 54 std::string operator[](UInt32 key_id) const; 55 56 UInt32 operator[](const char *str) const; 57 UInt32 operator[](const std::string &str) const; 58 59 std::string restore(UInt32 key_id) const; 60 void restore(UInt32 key_id, std::string *key) const; 61 std::size_t restore(UInt32 key_id, char *key_buf, 62 std::size_t key_buf_size) const; 63 64 UInt32 lookup(const char *str) const; 65 UInt32 lookup(const char *ptr, std::size_t length) const; 66 UInt32 lookup(const std::string &str) const; 67 68 std::size_t find(const char *str, 69 UInt32 *key_ids, std::size_t *key_lengths, 70 std::size_t max_num_results) const; 71 std::size_t find(const char *ptr, std::size_t length, 72 UInt32 *key_ids, std::size_t *key_lengths, 73 std::size_t max_num_results) const; 74 std::size_t find(const std::string &str, 75 UInt32 *key_ids, std::size_t *key_lengths, 76 std::size_t max_num_results) const; 77 78 std::size_t find(const char *str, 79 std::vector<UInt32> *key_ids = NULL, 80 std::vector<std::size_t> *key_lengths = NULL, 81 std::size_t max_num_results = MARISA_MAX_NUM_KEYS) const; 82 std::size_t find(const char *ptr, std::size_t length, 83 std::vector<UInt32> *key_ids = NULL, 84 std::vector<std::size_t> *key_lengths = NULL, 85 std::size_t max_num_results = MARISA_MAX_NUM_KEYS) const; 86 std::size_t find(const std::string &str, 87 std::vector<UInt32> *key_ids = NULL, 88 std::vector<std::size_t> *key_lengths = NULL, 89 std::size_t max_num_results = MARISA_MAX_NUM_KEYS) const; 90 91 UInt32 find_first(const char *str, 92 std::size_t *key_length = NULL) const; 93 UInt32 find_first(const char *ptr, std::size_t length, 94 std::size_t *key_length = NULL) const; 95 UInt32 find_first(const std::string &str, 96 std::size_t *key_length = NULL) const; 97 98 UInt32 find_last(const char *str, 99 std::size_t *key_length = NULL) const; 100 UInt32 find_last(const char *ptr, std::size_t length, 101 std::size_t *key_length = NULL) const; 102 UInt32 find_last(const std::string &str, 103 std::size_t *key_length = NULL) const; 104 105 // bool callback(UInt32 key_id, std::size_t key_length); 106 template <typename T> 107 std::size_t find_callback(const char *str, T callback) const; 108 template <typename T> 109 std::size_t find_callback(const char *ptr, std::size_t length, 110 T callback) const; 111 template <typename T> 112 std::size_t find_callback(const std::string &str, T callback) const; 113 114 std::size_t predict(const char *str, 115 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const; 116 std::size_t predict(const char *ptr, std::size_t length, 117 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const; 118 std::size_t predict(const std::string &str, 119 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const; 120 121 std::size_t predict(const char *str, 122 std::vector<UInt32> *key_ids = NULL, 123 std::vector<std::string> *keys = NULL, 124 std::size_t max_num_results = MARISA_MAX_NUM_KEYS) const; 125 std::size_t predict(const char *ptr, std::size_t length, 126 std::vector<UInt32> *key_ids = NULL, 127 std::vector<std::string> *keys = NULL, 128 std::size_t max_num_results = MARISA_MAX_NUM_KEYS) const; 129 std::size_t predict(const std::string &str, 130 std::vector<UInt32> *key_ids = NULL, 131 std::vector<std::string> *keys = NULL, 132 std::size_t max_num_results = MARISA_MAX_NUM_KEYS) const; 133 134 std::size_t predict_breadth_first(const char *str, 135 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const; 136 std::size_t predict_breadth_first(const char *ptr, std::size_t length, 137 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const; 138 std::size_t predict_breadth_first(const std::string &str, 139 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const; 140 141 std::size_t predict_breadth_first(const char *str, 142 std::vector<UInt32> *key_ids = NULL, 143 std::vector<std::string> *keys = NULL, 144 std::size_t max_num_results = MARISA_MAX_NUM_KEYS) const; 145 std::size_t predict_breadth_first(const char *ptr, std::size_t length, 146 std::vector<UInt32> *key_ids = NULL, 147 std::vector<std::string> *keys = NULL, 148 std::size_t max_num_results = MARISA_MAX_NUM_KEYS) const; 149 std::size_t predict_breadth_first(const std::string &str, 150 std::vector<UInt32> *key_ids = NULL, 151 std::vector<std::string> *keys = NULL, 152 std::size_t max_num_results = MARISA_MAX_NUM_KEYS) const; 153 154 std::size_t predict_depth_first(const char *str, 155 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const; 156 std::size_t predict_depth_first(const char *ptr, std::size_t length, 157 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const; 158 std::size_t predict_depth_first(const std::string &str, 159 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const; 160 161 std::size_t predict_depth_first(const char *str, 162 std::vector<UInt32> *key_ids = NULL, 163 std::vector<std::string> *keys = NULL, 164 std::size_t max_num_results = MARISA_MAX_NUM_KEYS) const; 165 std::size_t predict_depth_first(const char *ptr, std::size_t length, 166 std::vector<UInt32> *key_ids = NULL, 167 std::vector<std::string> *keys = NULL, 168 std::size_t max_num_results = MARISA_MAX_NUM_KEYS) const; 169 std::size_t predict_depth_first(const std::string &str, 170 std::vector<UInt32> *key_ids = NULL, 171 std::vector<std::string> *keys = NULL, 172 std::size_t max_num_results = MARISA_MAX_NUM_KEYS) const; 173 174 // bool callback(UInt32 key_id, const std::string &key); 175 template <typename T> 176 std::size_t predict_callback(const char *str, T callback) const; 177 template <typename T> 178 std::size_t predict_callback(const char *ptr, std::size_t length, 179 T callback) const; 180 template <typename T> 181 std::size_t predict_callback(const std::string &str, T callback) const; 182 183 bool empty() const; 184 std::size_t num_tries() const; 185 std::size_t num_keys() const; 186 std::size_t num_nodes() const; 187 std::size_t total_size() const; 188 189 void clear(); 190 void swap(Trie *rhs); 191 192 static UInt32 notfound(); 193 static std::size_t mismatch(); 194 195 private: 196 BitVector louds_; 197 Vector<UInt8> labels_; 198 BitVector terminal_flags_; 199 BitVector link_flags_; 200 IntVector links_; 201 std::auto_ptr<Trie> trie_; 202 Tail tail_; 203 UInt32 num_first_branches_; 204 UInt32 num_keys_; 205 206 void build_trie(Vector<Key<String> > &keys, 207 std::vector<UInt32> *key_ids, int flags); 208 void build_trie(Vector<Key<String> > &keys, 209 UInt32 *key_ids, int flags); 210 211 template <typename T> 212 void build_trie(Vector<Key<T> > &keys, 213 Vector<UInt32> *terminals, Progress &progress); 214 215 template <typename T> 216 void build_cur(Vector<Key<T> > &keys, 217 Vector<UInt32> *terminals, Progress &progress); 218 219 void build_next(Vector<Key<String> > &keys, 220 Vector<UInt32> *terminals, Progress &progress); 221 void build_next(Vector<Key<RString> > &rkeys, 222 Vector<UInt32> *terminals, Progress &progress); 223 224 template <typename T> 225 UInt32 sort_keys(Vector<Key<T> > &keys) const; 226 227 template <typename T> 228 void build_terminals(const Vector<Key<T> > &keys, 229 Vector<UInt32> *terminals) const; 230 231 void restore_(UInt32 key_id, std::string *key) const; 232 void trie_restore(UInt32 node, std::string *key) const; 233 void tail_restore(UInt32 node, std::string *key) const; 234 235 std::size_t restore_(UInt32 key_id, char *key_buf, 236 std::size_t key_buf_size) const; 237 void trie_restore(UInt32 node, char *key_buf, 238 std::size_t key_buf_size, std::size_t &key_pos) const; 239 void tail_restore(UInt32 node, char *key_buf, 240 std::size_t key_buf_size, std::size_t &key_pos) const; 241 242 template <typename T> 243 UInt32 lookup_(T query) const; 244 template <typename T> 245 bool find_child(UInt32 &node, T query, std::size_t &pos) const; 246 template <typename T> 247 std::size_t trie_match(UInt32 node, T query, std::size_t pos) const; 248 template <typename T> 249 std::size_t tail_match(UInt32 node, UInt32 link_id, 250 T query, std::size_t pos) const; 251 252 template <typename T, typename U, typename V> 253 std::size_t find_(T query, U key_ids, V key_lengths, 254 std::size_t max_num_results) const; 255 template <typename T> 256 UInt32 find_first_(T query, std::size_t *key_length) const; 257 template <typename T> 258 UInt32 find_last_(T query, std::size_t *key_length) const; 259 template <typename T, typename U> 260 std::size_t find_callback_(T query, U callback) const; 261 262 template <typename T, typename U, typename V> 263 std::size_t predict_breadth_first_(T query, U key_ids, V keys, 264 std::size_t max_num_results) const; 265 template <typename T, typename U, typename V> 266 std::size_t predict_depth_first_(T query, U key_ids, V keys, 267 std::size_t max_num_results) const; 268 template <typename T, typename U> 269 std::size_t predict_callback_(T query, U callback) const; 270 271 template <typename T> 272 bool predict_child(UInt32 &node, T query, std::size_t &pos, 273 std::string *key) const; 274 template <typename T> 275 std::size_t trie_prefix_match(UInt32 node, T query, 276 std::size_t pos, std::string *key) const; 277 template <typename T> 278 std::size_t tail_prefix_match(UInt32 node, UInt32 link_id, 279 T query, std::size_t pos, std::string *key) const; 280 281 UInt32 key_id_to_node(UInt32 key_id) const; 282 UInt32 node_to_key_id(UInt32 node) const; 283 UInt32 louds_pos_to_node(UInt32 louds_pos, UInt32 parent_node) const; 284 285 UInt32 get_child(UInt32 node) const; 286 UInt32 get_parent(UInt32 node) const; 287 288 bool has_link(UInt32 node) const; 289 UInt32 get_link_id(UInt32 node) const; 290 UInt32 get_link(UInt32 node) const; 291 UInt32 get_link(UInt32 node, UInt32 link_id) const; 292 293 bool has_link() const; 294 bool has_trie() const; 295 bool has_tail() const; 296 297 // Disallows copy and assignment. 298 Trie(const Trie &); 299 Trie &operator=(const Trie &); 300 }; 301 302 } // namespace marisa 303 304 #include "trie-inline.h" 305 306 #else // __cplusplus 307 308 #include <stdio.h> 309 310 #endif // __cplusplus 311 312 #ifdef __cplusplus 313 extern "C" { 314 #endif // __cplusplus 315 316 typedef struct marisa_trie_ marisa_trie; 317 318 marisa_status marisa_init(marisa_trie **h); 319 marisa_status marisa_end(marisa_trie *h); 320 321 marisa_status marisa_build(marisa_trie *h, const char * const *keys, 322 size_t num_keys, const size_t *key_lengths, const double *key_weights, 323 marisa_uint32 *key_ids, int flags); 324 325 marisa_status marisa_mmap(marisa_trie *h, const char *filename, 326 long offset, int whence); 327 marisa_status marisa_map(marisa_trie *h, const void *ptr, size_t size); 328 329 marisa_status marisa_load(marisa_trie *h, const char *filename, 330 long offset, int whence); 331 marisa_status marisa_fread(marisa_trie *h, FILE *file); 332 marisa_status marisa_read(marisa_trie *h, int fd); 333 334 marisa_status marisa_save(const marisa_trie *h, const char *filename, 335 int trunc_flag, long offset, int whence); 336 marisa_status marisa_fwrite(const marisa_trie *h, FILE *file); 337 marisa_status marisa_write(const marisa_trie *h, int fd); 338 339 marisa_status marisa_restore(const marisa_trie *h, marisa_uint32 key_id, 340 char *key_buf, size_t key_buf_size, size_t *key_length); 341 342 marisa_status marisa_lookup(const marisa_trie *h, 343 const char *ptr, size_t length, marisa_uint32 *key_id); 344 345 marisa_status marisa_find(const marisa_trie *h, 346 const char *ptr, size_t length, 347 marisa_uint32 *key_ids, size_t *key_lengths, 348 size_t max_num_results, size_t *num_results); 349 marisa_status marisa_find_first(const marisa_trie *h, 350 const char *ptr, size_t length, 351 marisa_uint32 *key_id, size_t *key_length); 352 marisa_status marisa_find_last(const marisa_trie *h, 353 const char *ptr, size_t length, 354 marisa_uint32 *key_id, size_t *key_length); 355 marisa_status marisa_find_callback(const marisa_trie *h, 356 const char *ptr, size_t length, 357 int (*callback)(void *, marisa_uint32, size_t), 358 void *first_arg_to_callback); 359 360 marisa_status marisa_predict(const marisa_trie *h, 361 const char *ptr, size_t length, marisa_uint32 *key_ids, 362 size_t max_num_results, size_t *num_results); 363 marisa_status marisa_predict_breadth_first(const marisa_trie *h, 364 const char *ptr, size_t length, marisa_uint32 *key_ids, 365 size_t max_num_results, size_t *num_results); 366 marisa_status marisa_predict_depth_first(const marisa_trie *h, 367 const char *ptr, size_t length, marisa_uint32 *key_ids, 368 size_t max_num_results, size_t *num_results); 369 marisa_status marisa_predict_callback(const marisa_trie *h, 370 const char *ptr, size_t length, 371 int (*callback)(void *, marisa_uint32, const char *, size_t), 372 void *first_arg_to_callback); 373 374 size_t marisa_get_num_tries(const marisa_trie *h); 375 size_t marisa_get_num_keys(const marisa_trie *h); 376 size_t marisa_get_num_nodes(const marisa_trie *h); 377 size_t marisa_get_total_size(const marisa_trie *h); 378 379 marisa_status marisa_clear(marisa_trie *h); 380 381 #ifdef __cplusplus 382 } // extern "C" 383 #endif // __cplusplus 384 385 #endif // MARISA_TRIE_H_ 386