Home | History | Annotate | Download | only in marisa
      1 #ifndef MARISA_TRIE_H_
      2 #define MARISA_TRIE_H_
      3 
      4 #include "base.h"
      5 
      6 #ifdef __cplusplus
      7 
      8 #include <memory>
      9 #include <vector>
     10 
     11 #include "progress.h"
     12 #include "key.h"
     13 #include "query.h"
     14 #include "container.h"
     15 #include "intvector.h"
     16 #include "bitvector.h"
     17 #include "tail.h"
     18 
     19 namespace marisa {
     20 
     21 class Trie {
          // Trie is the C++ interface declared by this header. Going by the member
          // names it groups: construction from a key set (build), memory mapping
          // and file/stream I/O (mmap/map/load/read/save/write), key-id <-> key
          // conversion (operator[], restore, lookup), and two families of queries
          // (find*, predict*).
          //
          // Only declarations appear here. The header includes "trie-inline.h"
          // right after the namespace closes, which is where the inline/template
          // definitions are expected to live.
          //
          // NOTE(review): behavioural descriptions in the comments below are
          // inferred from names and signatures only -- confirm each one against
          // the definitions before relying on it.
          //
          // Most operations are offered in three parallel spellings of the input:
          //   (const char *str)                    -- presumably NUL-terminated,
          //   (const char *ptr, std::size_t length) -- explicit length,
          //   (const std::string &str).
     22  public:
          // Default constructor; the only constructor callers can use, since the
          // copy constructor is declared private at the end of the class.
     23   Trie();
     24 
          // build(): raw-array form. keys points at num_keys C strings.
          // key_lengths and key_weights are optional (NULL by default).
          // key_ids is a non-const output pointer, presumably receiving the id
          // assigned to each input key -- TODO confirm. flags defaults to 0; the
          // accepted flag values are defined outside this file.
     25   void build(const char * const *keys, std::size_t num_keys,
     26       const std::size_t *key_lengths = NULL,
     27       const double *key_weights = NULL,
     28       UInt32 *key_ids = NULL, int flags = 0);
     29 
          // build(): std::vector forms. The second one pairs each key with a
          // double, presumably the same per-key weight as key_weights above.
     30   void build(const std::vector<std::string> &keys,
     31       std::vector<UInt32> *key_ids = NULL, int flags = 0);
     32   void build(const std::vector<std::pair<std::string, double> > &keys,
     33       std::vector<UInt32> *key_ids = NULL, int flags = 0);
     34 
          // Mapping entry points. Mapper is a project type declared elsewhere.
          // offset/whence default to 0/SEEK_SET, i.e. fseek-style arguments --
          // presumably interpreted the same way; confirm.
     35   void mmap(Mapper *mapper, const char *filename,
     36       long offset = 0, int whence = SEEK_SET);
     37   void map(const void *ptr, std::size_t size);
     38   void map(Mapper &mapper);
     39 
          // Input from a file name, a std::FILE, an int fd (presumably a file
          // descriptor), a std::istream, or a project Reader. All are non-const.
     40   void load(const char *filename,
     41       long offset = 0, int whence = SEEK_SET);
     42   void fread(std::FILE *file);
     43   void read(int fd);
     44   void read(std::istream &stream);
     45   void read(Reader &reader);
     46 
          // Output counterparts of the readers above; all const. save() takes
          // an extra trunc_flag (default true) ahead of offset/whence.
     47   void save(const char *filename, bool trunc_flag = true,
     48       long offset = 0, int whence = SEEK_SET) const;
     49   void fwrite(std::FILE *file) const;
     50   void write(int fd) const;
     51   void write(std::ostream &stream) const;
     52   void write(Writer &writer) const;
     53 
          // operator[] is overloaded in both directions: a UInt32 key id yields a
          // std::string, and a string yields a UInt32. Presumably shorthand for
          // restore() and lookup() respectively -- confirm in the definitions.
     54   std::string operator[](UInt32 key_id) const;
     55 
     56   UInt32 operator[](const char *str) const;
     57   UInt32 operator[](const std::string &str) const;
     58 
          // restore(): key id -> key, delivered by value, through *key, or into a
          // caller-supplied buffer of key_buf_size chars. The buffer form returns
          // std::size_t, presumably the key length (the C wrapper marisa_restore
          // reports it through a separate key_length out-parameter) -- confirm.
     59   std::string restore(UInt32 key_id) const;
     60   void restore(UInt32 key_id, std::string *key) const;
     61   std::size_t restore(UInt32 key_id, char *key_buf,
     62       std::size_t key_buf_size) const;
     63 
          // lookup(): string -> UInt32. Presumably the id of an exactly matching
          // key, with notfound() as the failure value -- confirm.
     64   UInt32 lookup(const char *str) const;
     65   UInt32 lookup(const char *ptr, std::size_t length) const;
     66   UInt32 lookup(const std::string &str) const;
     67 
          // find(), array-output form: results go to key_ids / key_lengths, with
          // max_num_results as a required argument. Returns std::size_t,
          // presumably the number of results. Which keys count as matches is not
          // visible here; the key_lengths output suggests matches are identified
          // by a length within the query -- TODO confirm.
     68   std::size_t find(const char *str,
     69       UInt32 *key_ids, std::size_t *key_lengths,
     70       std::size_t max_num_results) const;
     71   std::size_t find(const char *ptr, std::size_t length,
     72       UInt32 *key_ids, std::size_t *key_lengths,
     73       std::size_t max_num_results) const;
     74   std::size_t find(const std::string &str,
     75       UInt32 *key_ids, std::size_t *key_lengths,
     76       std::size_t max_num_results) const;
     77 
          // find(), std::vector-output form: both outputs are optional (NULL by
          // default) and max_num_results defaults to MARISA_MAX_NUM_KEYS.
     78   std::size_t find(const char *str,
     79       std::vector<UInt32> *key_ids = NULL,
     80       std::vector<std::size_t> *key_lengths = NULL,
     81       std::size_t max_num_results = MARISA_MAX_NUM_KEYS) const;
     82   std::size_t find(const char *ptr, std::size_t length,
     83       std::vector<UInt32> *key_ids = NULL,
     84       std::vector<std::size_t> *key_lengths = NULL,
     85       std::size_t max_num_results = MARISA_MAX_NUM_KEYS) const;
     86   std::size_t find(const std::string &str,
     87       std::vector<UInt32> *key_ids = NULL,
     88       std::vector<std::size_t> *key_lengths = NULL,
     89       std::size_t max_num_results = MARISA_MAX_NUM_KEYS) const;
     90 
          // find_first(): single-result variant returning a UInt32 with an
          // optional key_length output (NULL by default).
     91   UInt32 find_first(const char *str,
     92       std::size_t *key_length = NULL) const;
     93   UInt32 find_first(const char *ptr, std::size_t length,
     94       std::size_t *key_length = NULL) const;
     95   UInt32 find_first(const std::string &str,
     96       std::size_t *key_length = NULL) const;
     97 
          // find_last(): same shape as find_first(); how "first" and "last" are
          // ordered is defined by the implementation, not by this header.
     98   UInt32 find_last(const char *str,
     99       std::size_t *key_length = NULL) const;
    100   UInt32 find_last(const char *ptr, std::size_t length,
    101       std::size_t *key_length = NULL) const;
    102   UInt32 find_last(const std::string &str,
    103       std::size_t *key_length = NULL) const;
    104 
          // find_callback(): T is any callable, passed by value, that can be
          // invoked with the signature shown on the next line. The meaning of
          // its bool result is not stated here -- presumably "keep going";
          // confirm.
    105   // bool callback(UInt32 key_id, std::size_t key_length);
    106   template <typename T>
    107   std::size_t find_callback(const char *str, T callback) const;
    108   template <typename T>
    109   std::size_t find_callback(const char *ptr, std::size_t length,
    110       T callback) const;
    111   template <typename T>
    112   std::size_t find_callback(const std::string &str, T callback) const;
    113 
          // predict(), array-output form: outputs are key ids and the keys
          // themselves (std::string), unlike find() which outputs lengths.
          // Whether plain predict() follows the breadth-first or depth-first
          // variant declared further down is not visible in this header.
    114   std::size_t predict(const char *str,
    115       UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
    116   std::size_t predict(const char *ptr, std::size_t length,
    117       UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
    118   std::size_t predict(const std::string &str,
    119       UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
    120 
          // predict(), std::vector-output form with optional outputs and a
          // defaulted max_num_results.
    121   std::size_t predict(const char *str,
    122       std::vector<UInt32> *key_ids = NULL,
    123       std::vector<std::string> *keys = NULL,
    124       std::size_t max_num_results = MARISA_MAX_NUM_KEYS) const;
    125   std::size_t predict(const char *ptr, std::size_t length,
    126       std::vector<UInt32> *key_ids = NULL,
    127       std::vector<std::string> *keys = NULL,
    128       std::size_t max_num_results = MARISA_MAX_NUM_KEYS) const;
    129   std::size_t predict(const std::string &str,
    130       std::vector<UInt32> *key_ids = NULL,
    131       std::vector<std::string> *keys = NULL,
    132       std::size_t max_num_results = MARISA_MAX_NUM_KEYS) const;
    133 
          // predict_breadth_first(): same six signatures as predict(), with the
          // traversal order named explicitly.
    134   std::size_t predict_breadth_first(const char *str,
    135       UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
    136   std::size_t predict_breadth_first(const char *ptr, std::size_t length,
    137       UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
    138   std::size_t predict_breadth_first(const std::string &str,
    139       UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
    140 
    141   std::size_t predict_breadth_first(const char *str,
    142       std::vector<UInt32> *key_ids = NULL,
    143       std::vector<std::string> *keys = NULL,
    144       std::size_t max_num_results = MARISA_MAX_NUM_KEYS) const;
    145   std::size_t predict_breadth_first(const char *ptr, std::size_t length,
    146       std::vector<UInt32> *key_ids = NULL,
    147       std::vector<std::string> *keys = NULL,
    148       std::size_t max_num_results = MARISA_MAX_NUM_KEYS) const;
    149   std::size_t predict_breadth_first(const std::string &str,
    150       std::vector<UInt32> *key_ids = NULL,
    151       std::vector<std::string> *keys = NULL,
    152       std::size_t max_num_results = MARISA_MAX_NUM_KEYS) const;
    153 
          // predict_depth_first(): same six signatures again, depth-first by name.
    154   std::size_t predict_depth_first(const char *str,
    155       UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
    156   std::size_t predict_depth_first(const char *ptr, std::size_t length,
    157       UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
    158   std::size_t predict_depth_first(const std::string &str,
    159       UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
    160 
    161   std::size_t predict_depth_first(const char *str,
    162       std::vector<UInt32> *key_ids = NULL,
    163       std::vector<std::string> *keys = NULL,
    164       std::size_t max_num_results = MARISA_MAX_NUM_KEYS) const;
    165   std::size_t predict_depth_first(const char *ptr, std::size_t length,
    166       std::vector<UInt32> *key_ids = NULL,
    167       std::vector<std::string> *keys = NULL,
    168       std::size_t max_num_results = MARISA_MAX_NUM_KEYS) const;
    169   std::size_t predict_depth_first(const std::string &str,
    170       std::vector<UInt32> *key_ids = NULL,
    171       std::vector<std::string> *keys = NULL,
    172       std::size_t max_num_results = MARISA_MAX_NUM_KEYS) const;
    173 
          // predict_callback(): as find_callback(), but the callable receives
          // the key text instead of a length, per the signature on the next line.
    174   // bool callback(UInt32 key_id, const std::string &key);
    175   template <typename T>
    176   std::size_t predict_callback(const char *str, T callback) const;
    177   template <typename T>
    178   std::size_t predict_callback(const char *ptr, std::size_t length,
    179       T callback) const;
    180   template <typename T>
    181   std::size_t predict_callback(const std::string &str, T callback) const;
    182 
          // Const accessors. What exactly each count covers (e.g. whether
          // nested tries are included) is decided by the definitions.
    183   bool empty() const;
    184   std::size_t num_tries() const;
    185   std::size_t num_keys() const;
    186   std::size_t num_nodes() const;
    187   std::size_t total_size() const;
    188 
          // clear() and swap() are the mutating utilities; swap() takes its
          // peer by pointer rather than by reference.
    189   void clear();
    190   void swap(Trie *rhs);
    191 
          // Static value providers. By name and return type these look like the
          // failure sentinels for id-returning calls (UInt32) and for
          // position/length-returning helpers (std::size_t) -- confirm.
    192   static UInt32 notfound();
    193   static std::size_t mismatch();
    194 
    195  private:
          // Data members. Declaration order is significant in C++ (it fixes
          // initialization order and object layout), so keep it stable.
          // The names suggest a LOUDS-style bit-vector encoding with per-node
          // labels, terminal flags and link flags -- TODO confirm in the
          // definitions.
    196   BitVector louds_;
    197   Vector<UInt8> labels_;
    198   BitVector terminal_flags_;
    199   BitVector link_flags_;
    200   IntVector links_;
          // Owning pointer to another Trie; together with num_tries() and
          // has_trie() this points to a nested (multi-level) layout -- presumably
          // the next level; confirm.
          // NOTE(review): std::auto_ptr is deprecated since C++11 and removed in
          // C++17. Migrating to std::unique_ptr requires auditing every use in
          // the definitions, which are not visible from this header.
    201   std::auto_ptr<Trie> trie_;
    202   Tail tail_;
    203   UInt32 num_first_branches_;
    204   UInt32 num_keys_;
    205 
          // Build helpers. The two non-template build_trie() overloads mirror the
          // two kinds of key_ids output accepted by the public build() overloads
          // (std::vector<UInt32> * and UInt32 *).
    206   void build_trie(Vector<Key<String> > &keys,
    207       std::vector<UInt32> *key_ids, int flags);
    208   void build_trie(Vector<Key<String> > &keys,
    209       UInt32 *key_ids, int flags);
    210 
          // Templated over the key string type T; String and RString are the two
          // instantiations visible in build_next() below.
    211   template <typename T>
    212   void build_trie(Vector<Key<T> > &keys,
    213       Vector<UInt32> *terminals, Progress &progress);
    214 
    215   template <typename T>
    216   void build_cur(Vector<Key<T> > &keys,
    217       Vector<UInt32> *terminals, Progress &progress);
    218 
    219   void build_next(Vector<Key<String> > &keys,
    220       Vector<UInt32> *terminals, Progress &progress);
    221   void build_next(Vector<Key<RString> > &rkeys,
    222       Vector<UInt32> *terminals, Progress &progress);
    223 
          // sort_keys() is const yet takes keys by non-const reference: it may
          // reorder the caller's vector but not the Trie itself.
    224   template <typename T>
    225   UInt32 sort_keys(Vector<Key<T> > &keys) const;
    226 
    227   template <typename T>
    228   void build_terminals(const Vector<Key<T> > &keys,
    229       Vector<UInt32> *terminals) const;
    230 
          // Restore helpers, std::string flavour. trie_restore/tail_restore are
          // addressed by node rather than by key id.
    231   void restore_(UInt32 key_id, std::string *key) const;
    232   void trie_restore(UInt32 node, std::string *key) const;
    233   void tail_restore(UInt32 node, std::string *key) const;
    234 
          // Restore helpers, caller-buffer flavour. key_pos is an in/out
          // reference, presumably the current write position in key_buf.
    235   std::size_t restore_(UInt32 key_id, char *key_buf,
    236       std::size_t key_buf_size) const;
    237   void trie_restore(UInt32 node, char *key_buf,
    238       std::size_t key_buf_size, std::size_t &key_pos) const;
    239   void tail_restore(UInt32 node, char *key_buf,
    240       std::size_t key_buf_size, std::size_t &key_pos) const;
    241 
          // Trailing-underscore templates are the private counterparts of the
          // public overload sets. T is the query type, passed by value --
          // presumably one of the wrappers from "query.h"; confirm.
          // find_child() updates node and pos through references.
    242   template <typename T>
    243   UInt32 lookup_(T query) const;
    244   template <typename T>
    245   bool find_child(UInt32 &node, T query, std::size_t &pos) const;
    246   template <typename T>
    247   std::size_t trie_match(UInt32 node, T query, std::size_t pos) const;
    248   template <typename T>
    249   std::size_t tail_match(UInt32 node, UInt32 link_id,
    250       T query, std::size_t pos) const;
    251 
          // U and V abstract over the output sinks, so one template can serve
          // both the raw-pointer and the std::vector public overloads
          // (presumably; confirm at the call sites).
    252   template <typename T, typename U, typename V>
    253   std::size_t find_(T query, U key_ids, V key_lengths,
    254       std::size_t max_num_results) const;
    255   template <typename T>
    256   UInt32 find_first_(T query, std::size_t *key_length) const;
    257   template <typename T>
    258   UInt32 find_last_(T query, std::size_t *key_length) const;
    259   template <typename T, typename U>
    260   std::size_t find_callback_(T query, U callback) const;
    261 
    262   template <typename T, typename U, typename V>
    263   std::size_t predict_breadth_first_(T query, U key_ids, V keys,
    264       std::size_t max_num_results) const;
    265   template <typename T, typename U, typename V>
    266   std::size_t predict_depth_first_(T query, U key_ids, V keys,
    267       std::size_t max_num_results) const;
    268   template <typename T, typename U>
    269   std::size_t predict_callback_(T query, U callback) const;
    270 
          // Prefix-matching helpers used on the predict side; the extra
          // std::string *key parameter distinguishes them from the *_match
          // helpers above.
    271   template <typename T>
    272   bool predict_child(UInt32 &node, T query, std::size_t &pos,
    273       std::string *key) const;
    274   template <typename T>
    275   std::size_t trie_prefix_match(UInt32 node, T query,
    276       std::size_t pos, std::string *key) const;
    277   template <typename T>
    278   std::size_t tail_prefix_match(UInt32 node, UInt32 link_id,
    279       T query, std::size_t pos, std::string *key) const;
    280 
          // Conversions between the three index spaces used internally:
          // key ids, nodes, and LOUDS positions.
    281   UInt32 key_id_to_node(UInt32 key_id) const;
    282   UInt32 node_to_key_id(UInt32 node) const;
    283   UInt32 louds_pos_to_node(UInt32 louds_pos, UInt32 parent_node) const;
    284 
          // Tree navigation by node index.
    285   UInt32 get_child(UInt32 node) const;
    286   UInt32 get_parent(UInt32 node) const;
    287 
          // Per-node link queries.
    288   bool has_link(UInt32 node) const;
    289   UInt32 get_link_id(UInt32 node) const;
    290   UInt32 get_link(UInt32 node) const;
    291   UInt32 get_link(UInt32 node, UInt32 link_id) const;
    292 
          // Whole-object predicates; note has_link() is overloaded with the
          // per-node version above.
    293   bool has_link() const;
    294   bool has_trie() const;
    295   bool has_tail() const;
    296 
          // Pre-C++11 non-copyable idiom: both operations are declared private
          // and are given no definition in this header.
    297   // Disallows copy and assignment.
    298   Trie(const Trie &);
    299   Trie &operator=(const Trie &);
    300 };
    301 
    302 }  // namespace marisa
    303 
    304 #include "trie-inline.h"
    305 
    306 #else  // __cplusplus
    307 
    308 #include <stdio.h>
    309 
    310 #endif  // __cplusplus
    311 
    312 #ifdef __cplusplus
    313 extern "C" {
    314 #endif  // __cplusplus
    315 
    316 typedef struct marisa_trie_ marisa_trie;
        /*
         * C API. marisa_trie is an opaque handle: struct marisa_trie_ is only
         * forward-declared in this header.
         *
         * Every function returns marisa_status (declared outside this file)
         * except the marisa_get_* accessors, which return size_t directly.
         * Compared with the C++ class above there are no default arguments,
         * results come back through out-pointers, and each operation has only
         * the (ptr, length) spelling of its input.
         *
         * NOTE(review): behavioural notes below are inferred from names and
         * signatures; confirm against the implementation.
         */
    317 
        /*
         * Handle lifetime. marisa_init receives the address of a handle pointer
         * (an out-parameter); marisa_end receives the handle itself. Presumably
         * these create and destroy the underlying object -- confirm.
         */
    318 marisa_status marisa_init(marisa_trie **h);
    319 marisa_status marisa_end(marisa_trie *h);
    320 
        /*
         * Same parameter list as the raw-array Trie::build(), but every
         * argument must be supplied explicitly.
         */
    321 marisa_status marisa_build(marisa_trie *h, const char * const *keys,
    322     size_t num_keys, const size_t *key_lengths, const double *key_weights,
    323     marisa_uint32 *key_ids, int flags);
    324 
        /* Mapping. No Mapper argument here, unlike Trie::mmap(). */
    325 marisa_status marisa_mmap(marisa_trie *h, const char *filename,
    326     long offset, int whence);
    327 marisa_status marisa_map(marisa_trie *h, const void *ptr, size_t size);
    328 
        /* Input from a file name, a FILE *, or an int fd. */
    329 marisa_status marisa_load(marisa_trie *h, const char *filename,
    330     long offset, int whence);
    331 marisa_status marisa_fread(marisa_trie *h, FILE *file);
    332 marisa_status marisa_read(marisa_trie *h, int fd);
    333 
        /*
         * Output counterparts; the handle is const. trunc_flag is an int here
         * where the C++ save() takes a bool.
         */
    334 marisa_status marisa_save(const marisa_trie *h, const char *filename,
    335     int trunc_flag, long offset, int whence);
    336 marisa_status marisa_fwrite(const marisa_trie *h, FILE *file);
    337 marisa_status marisa_write(const marisa_trie *h, int fd);
    338 
        /*
         * Key id -> key into a caller buffer of key_buf_size chars; the length
         * is reported through *key_length.
         */
    339 marisa_status marisa_restore(const marisa_trie *h, marisa_uint32 key_id,
    340     char *key_buf, size_t key_buf_size, size_t *key_length);
    341 
        /* Key -> key id, reported through *key_id. */
    342 marisa_status marisa_lookup(const marisa_trie *h,
    343     const char *ptr, size_t length, marisa_uint32 *key_id);
    344 
        /*
         * find family. The result count goes to *num_results. The callback
         * variant passes first_arg_to_callback as the callback's leading
         * void * argument (per the parameter's name), followed by a key id and
         * a length.
         */
    345 marisa_status marisa_find(const marisa_trie *h,
    346     const char *ptr, size_t length,
    347     marisa_uint32 *key_ids, size_t *key_lengths,
    348     size_t max_num_results, size_t *num_results);
    349 marisa_status marisa_find_first(const marisa_trie *h,
    350     const char *ptr, size_t length,
    351     marisa_uint32 *key_id, size_t *key_length);
    352 marisa_status marisa_find_last(const marisa_trie *h,
    353     const char *ptr, size_t length,
    354     marisa_uint32 *key_id, size_t *key_length);
    355 marisa_status marisa_find_callback(const marisa_trie *h,
    356     const char *ptr, size_t length,
    357     int (*callback)(void *, marisa_uint32, size_t),
    358     void *first_arg_to_callback);
    359 
        /*
         * predict family. The array forms output key ids only -- there is no
         * parameter for the key strings, unlike the C++ overloads. The callback
         * form does receive key text, as a const char * plus a size_t.
         */
    360 marisa_status marisa_predict(const marisa_trie *h,
    361     const char *ptr, size_t length, marisa_uint32 *key_ids,
    362     size_t max_num_results, size_t *num_results);
    363 marisa_status marisa_predict_breadth_first(const marisa_trie *h,
    364     const char *ptr, size_t length, marisa_uint32 *key_ids,
    365     size_t max_num_results, size_t *num_results);
    366 marisa_status marisa_predict_depth_first(const marisa_trie *h,
    367     const char *ptr, size_t length, marisa_uint32 *key_ids,
    368     size_t max_num_results, size_t *num_results);
    369 marisa_status marisa_predict_callback(const marisa_trie *h,
    370     const char *ptr, size_t length,
    371     int (*callback)(void *, marisa_uint32, const char *, size_t),
    372     void *first_arg_to_callback);
    373 
        /*
         * Accessors returning size_t directly, so they have no status channel
         * for reporting errors.
         */
    374 size_t marisa_get_num_tries(const marisa_trie *h);
    375 size_t marisa_get_num_keys(const marisa_trie *h);
    376 size_t marisa_get_num_nodes(const marisa_trie *h);
    377 size_t marisa_get_total_size(const marisa_trie *h);
    378 
        /* Takes a non-const handle; counterpart of Trie::clear() by name. */
    379 marisa_status marisa_clear(marisa_trie *h);
    380 
    381 #ifdef __cplusplus
    382 }  // extern "C"
    383 #endif  // __cplusplus
    384 
    385 #endif  // MARISA_TRIE_H_
    386