Home | History | Annotate | Download | only in marisa
      1 #include "trie.h"
      2 
      3 extern "C" {
      4 
      5 namespace {
      6 
      7 class FindCallback {
      8  public:
      9   typedef int (*Func)(void *, marisa_uint32, size_t);
     10 
     11   FindCallback(Func func, void *first_arg)
     12       : func_(func), first_arg_(first_arg) {}
     13   FindCallback(const FindCallback &callback)
     14       : func_(callback.func_), first_arg_(callback.first_arg_) {}
     15 
     16   bool operator()(marisa::UInt32 key_id, std::size_t key_length) const {
     17     return func_(first_arg_, key_id, key_length) != 0;
     18   }
     19 
     20  private:
     21   Func func_;
     22   void *first_arg_;
     23 
     24   // Disallows assignment.
     25   FindCallback &operator=(const FindCallback &);
     26 };
     27 
     28 class PredictCallback {
     29  public:
     30   typedef int (*Func)(void *, marisa_uint32, const char *, size_t);
     31 
     32   PredictCallback(Func func, void *first_arg)
     33       : func_(func), first_arg_(first_arg) {}
     34   PredictCallback(const PredictCallback &callback)
     35       : func_(callback.func_), first_arg_(callback.first_arg_) {}
     36 
     37   bool operator()(marisa::UInt32 key_id, const std::string &key) const {
     38     return func_(first_arg_, key_id, key.c_str(), key.length()) != 0;
     39   }
     40 
     41  private:
     42   Func func_;
     43   void *first_arg_;
     44 
     45   // Disallows assignment.
     46   PredictCallback &operator=(const PredictCallback &);
     47 };
     48 
     49 }  // namespace
     50 
     51 struct marisa_trie_ {
     52  public:
     53   marisa_trie_() : trie(), mapper() {}
     54 
     55   marisa::Trie trie;
     56   marisa::Mapper mapper;
     57 
     58  private:
     59   // Disallows copy and assignment.
     60   marisa_trie_(const marisa_trie_ &);
     61   marisa_trie_ &operator=(const marisa_trie_ &);
     62 };
     63 
     64 marisa_status marisa_init(marisa_trie **h) {
     65   if ((h == NULL) || (*h != NULL)) {
     66     return MARISA_HANDLE_ERROR;
     67   }
     68   *h = new (std::nothrow) marisa_trie_();
     69   return (*h != NULL) ? MARISA_OK : MARISA_MEMORY_ERROR;
     70 }
     71 
     72 marisa_status marisa_end(marisa_trie *h) {
     73   if (h == NULL) {
     74     return MARISA_HANDLE_ERROR;
     75   }
     76   delete h;
     77   return MARISA_OK;
     78 }
     79 
     80 marisa_status marisa_build(marisa_trie *h, const char * const *keys,
     81     size_t num_keys, const size_t *key_lengths, const double *key_weights,
     82     marisa_uint32 *key_ids, int flags) {
     83   if (h == NULL) {
     84     return MARISA_HANDLE_ERROR;
     85   }
     86   h->trie.build(keys, num_keys, key_lengths, key_weights, key_ids, flags);
     87   h->mapper.clear();
     88   return MARISA_OK;
     89 }
     90 
     91 marisa_status marisa_mmap(marisa_trie *h, const char *filename,
     92     long offset, int whence) {
     93   if (h == NULL) {
     94     return MARISA_HANDLE_ERROR;
     95   }
     96   h->trie.mmap(&h->mapper, filename, offset, whence);
     97   return MARISA_OK;
     98 }
     99 
    100 marisa_status marisa_map(marisa_trie *h, const void *ptr, size_t size) {
    101   if (h == NULL) {
    102     return MARISA_HANDLE_ERROR;
    103   }
    104   h->trie.map(ptr, size);
    105   h->mapper.clear();
    106   return MARISA_OK;
    107 }
    108 
    109 marisa_status marisa_load(marisa_trie *h, const char *filename,
    110     long offset, int whence) {
    111   if (h == NULL) {
    112     return MARISA_HANDLE_ERROR;
    113   }
    114   h->trie.load(filename, offset, whence);
    115   h->mapper.clear();
    116   return MARISA_OK;
    117 }
    118 
    119 marisa_status marisa_fread(marisa_trie *h, FILE *file) {
    120   if (h == NULL) {
    121     return MARISA_HANDLE_ERROR;
    122   }
    123   h->trie.fread(file);
    124   h->mapper.clear();
    125   return MARISA_OK;
    126 }
    127 
    128 marisa_status marisa_read(marisa_trie *h, int fd) {
    129   if (h == NULL) {
    130     return MARISA_HANDLE_ERROR;
    131   }
    132   h->trie.read(fd);
    133   h->mapper.clear();
    134   return MARISA_OK;
    135 }
    136 
    137 marisa_status marisa_save(const marisa_trie *h, const char *filename,
    138     int trunc_flag, long offset, int whence) {
    139   if (h == NULL) {
    140     return MARISA_HANDLE_ERROR;
    141   }
    142   h->trie.save(filename, trunc_flag != 0, offset, whence);
    143   return MARISA_OK;
    144 }
    145 
    146 marisa_status marisa_fwrite(const marisa_trie *h, FILE *file) {
    147   if (h == NULL) {
    148     return MARISA_HANDLE_ERROR;
    149   }
    150   h->trie.fwrite(file);
    151   return MARISA_OK;
    152 }
    153 
    154 marisa_status marisa_write(const marisa_trie *h, int fd) {
    155   if (h == NULL) {
    156     return MARISA_HANDLE_ERROR;
    157   }
    158   h->trie.write(fd);
    159   return MARISA_OK;
    160 }
    161 
    162 marisa_status marisa_restore(const marisa_trie *h, marisa_uint32 key_id,
    163     char *key_buf, size_t key_buf_size, size_t *key_length) {
    164   if (h == NULL) {
    165     return MARISA_HANDLE_ERROR;
    166   } else if (key_length == NULL) {
    167     return MARISA_PARAM_ERROR;
    168   }
    169   *key_length = h->trie.restore(key_id, key_buf, key_buf_size);
    170   return MARISA_OK;
    171 }
    172 
    173 marisa_status marisa_lookup(const marisa_trie *h,
    174     const char *ptr, size_t length, marisa_uint32 *key_id) {
    175   if (h == NULL) {
    176     return MARISA_HANDLE_ERROR;
    177   } else if (key_id == NULL) {
    178     return MARISA_PARAM_ERROR;
    179   }
    180   if (length == MARISA_ZERO_TERMINATED) {
    181     *key_id = h->trie.lookup(ptr);
    182   } else {
    183     *key_id = h->trie.lookup(ptr, length);
    184   }
    185   return MARISA_OK;
    186 }
    187 
    188 marisa_status marisa_find(const marisa_trie *h,
    189     const char *ptr, size_t length,
    190     marisa_uint32 *key_ids, size_t *key_lengths,
    191     size_t max_num_results, size_t *num_results) {
    192   if (h == NULL) {
    193     return MARISA_HANDLE_ERROR;
    194   } else if (num_results == NULL) {
    195     return MARISA_PARAM_ERROR;
    196   }
    197   if (length == MARISA_ZERO_TERMINATED) {
    198     *num_results = h->trie.find(ptr, key_ids, key_lengths, max_num_results);
    199   } else {
    200     *num_results = h->trie.find(ptr, length,
    201         key_ids, key_lengths, max_num_results);
    202   }
    203   return MARISA_OK;
    204 }
    205 
    206 marisa_status marisa_find_first(const marisa_trie *h,
    207     const char *ptr, size_t length,
    208     marisa_uint32 *key_id, size_t *key_length) {
    209   if (h == NULL) {
    210     return MARISA_HANDLE_ERROR;
    211   } else if (key_id == NULL) {
    212     return MARISA_PARAM_ERROR;
    213   }
    214   if (length == MARISA_ZERO_TERMINATED) {
    215     *key_id = h->trie.find_first(ptr, key_length);
    216   } else {
    217     *key_id = h->trie.find_first(ptr, length, key_length);
    218   }
    219   return MARISA_OK;
    220 }
    221 
    222 marisa_status marisa_find_last(const marisa_trie *h,
    223     const char *ptr, size_t length,
    224     marisa_uint32 *key_id, size_t *key_length) {
    225   if (h == NULL) {
    226     return MARISA_HANDLE_ERROR;
    227   } else if (key_id == NULL) {
    228     return MARISA_PARAM_ERROR;
    229   }
    230   if (length == MARISA_ZERO_TERMINATED) {
    231     *key_id = h->trie.find_last(ptr, key_length);
    232   } else {
    233     *key_id = h->trie.find_last(ptr, length, key_length);
    234   }
    235   return MARISA_OK;
    236 }
    237 
    238 marisa_status marisa_find_callback(const marisa_trie *h,
    239     const char *ptr, size_t length,
    240     int (*callback)(void *, marisa_uint32, size_t),
    241     void *first_arg_to_callback) {
    242   if (h == NULL) {
    243     return MARISA_HANDLE_ERROR;
    244   } else if (callback == NULL) {
    245     return MARISA_PARAM_ERROR;
    246   }
    247   if (length == MARISA_ZERO_TERMINATED) {
    248     h->trie.find_callback(ptr,
    249         ::FindCallback(callback, first_arg_to_callback));
    250   } else {
    251     h->trie.find_callback(ptr, length,
    252         ::FindCallback(callback, first_arg_to_callback));
    253   }
    254   return MARISA_OK;
    255 }
    256 
    257 marisa_status marisa_predict(const marisa_trie *h,
    258     const char *ptr, size_t length, marisa_uint32 *key_ids,
    259     size_t max_num_results, size_t *num_results) {
    260   return marisa_predict_breadth_first(h, ptr, length,
    261       key_ids, max_num_results, num_results);
    262 }
    263 
    264 marisa_status marisa_predict_breadth_first(const marisa_trie *h,
    265     const char *ptr, size_t length, marisa_uint32 *key_ids,
    266     size_t max_num_results, size_t *num_results) {
    267   if (h == NULL) {
    268     return MARISA_HANDLE_ERROR;
    269   } else if (num_results == NULL) {
    270     return MARISA_PARAM_ERROR;
    271   }
    272   if (length == MARISA_ZERO_TERMINATED) {
    273     *num_results = h->trie.predict_breadth_first(
    274         ptr, key_ids, NULL, max_num_results);
    275   } else {
    276     *num_results = h->trie.predict_breadth_first(
    277         ptr, length, key_ids, NULL, max_num_results);
    278   }
    279   return MARISA_OK;
    280 }
    281 
    282 marisa_status marisa_predict_depth_first(const marisa_trie *h,
    283     const char *ptr, size_t length, marisa_uint32 *key_ids,
    284     size_t max_num_results, size_t *num_results) {
    285   if (h == NULL) {
    286     return MARISA_HANDLE_ERROR;
    287   } else if (num_results == NULL) {
    288     return MARISA_PARAM_ERROR;
    289   }
    290   if (length == MARISA_ZERO_TERMINATED) {
    291     *num_results = h->trie.predict_depth_first(
    292         ptr, key_ids, NULL, max_num_results);
    293   } else {
    294     *num_results = h->trie.predict_depth_first(
    295         ptr, length, key_ids, NULL, max_num_results);
    296   }
    297   return MARISA_OK;
    298 }
    299 
    300 marisa_status marisa_predict_callback(const marisa_trie *h,
    301     const char *ptr, size_t length,
    302     int (*callback)(void *, marisa_uint32, const char *, size_t),
    303     void *first_arg_to_callback) {
    304   if (h == NULL) {
    305     return MARISA_HANDLE_ERROR;
    306   } else if (callback == NULL) {
    307     return MARISA_PARAM_ERROR;
    308   }
    309   if (length == MARISA_ZERO_TERMINATED) {
    310     h->trie.predict_callback(ptr,
    311         ::PredictCallback(callback, first_arg_to_callback));
    312   } else {
    313     h->trie.predict_callback(ptr, length,
    314         ::PredictCallback(callback, first_arg_to_callback));
    315   }
    316   return MARISA_OK;
    317 }
    318 
    319 size_t marisa_get_num_tries(const marisa_trie *h) {
    320   return (h != NULL) ? h->trie.num_tries() : 0;
    321 }
    322 
    323 size_t marisa_get_num_keys(const marisa_trie *h) {
    324   return (h != NULL) ? h->trie.num_keys() : 0;
    325 }
    326 
    327 size_t marisa_get_num_nodes(const marisa_trie *h) {
    328   return (h != NULL) ? h->trie.num_nodes() : 0;
    329 }
    330 
    331 size_t marisa_get_total_size(const marisa_trie *h) {
    332   return (h != NULL) ? h->trie.total_size() : 0;
    333 }
    334 
    335 marisa_status marisa_clear(marisa_trie *h) {
    336   if (h == NULL) {
    337     return MARISA_HANDLE_ERROR;
    338   }
    339   h->trie.clear();
    340   h->mapper.clear();
    341   return MARISA_OK;
    342 }
    343 
    344 }  // extern "C"
    345