Home | History | Annotate | Download | only in marisa_alpha
      1 #include "trie.h"
      2 
      3 extern "C" {
      4 
      5 namespace {
      6 
      7 class FindCallback {
      8  public:
      9   typedef int (*Func)(void *, marisa_alpha_uint32, size_t);
     10 
     11   FindCallback(Func func, void *first_arg)
     12       : func_(func), first_arg_(first_arg) {}
     13   FindCallback(const FindCallback &callback)
     14       : func_(callback.func_), first_arg_(callback.first_arg_) {}
     15 
     16   bool operator()(marisa_alpha::UInt32 key_id, std::size_t key_length) const {
     17     return func_(first_arg_, key_id, key_length) != 0;
     18   }
     19 
     20  private:
     21   Func func_;
     22   void *first_arg_;
     23 
     24   // Disallows assignment.
     25   FindCallback &operator=(const FindCallback &);
     26 };
     27 
     28 class PredictCallback {
     29  public:
     30   typedef int (*Func)(void *, marisa_alpha_uint32, const char *, size_t);
     31 
     32   PredictCallback(Func func, void *first_arg)
     33       : func_(func), first_arg_(first_arg) {}
     34   PredictCallback(const PredictCallback &callback)
     35       : func_(callback.func_), first_arg_(callback.first_arg_) {}
     36 
     37   bool operator()(marisa_alpha::UInt32 key_id, const std::string &key) const {
     38     return func_(first_arg_, key_id, key.c_str(), key.length()) != 0;
     39   }
     40 
     41  private:
     42   Func func_;
     43   void *first_arg_;
     44 
     45   // Disallows assignment.
     46   PredictCallback &operator=(const PredictCallback &);
     47 };
     48 
     49 }  // namespace
     50 
     51 struct marisa_alpha_trie_ {
     52  public:
     53   marisa_alpha_trie_() : trie(), mapper() {}
     54 
     55   marisa_alpha::Trie trie;
     56   marisa_alpha::Mapper mapper;
     57 
     58  private:
     59   // Disallows copy and assignment.
     60   marisa_alpha_trie_(const marisa_alpha_trie_ &);
     61   marisa_alpha_trie_ &operator=(const marisa_alpha_trie_ &);
     62 };
     63 
     64 marisa_alpha_status marisa_alpha_init(marisa_alpha_trie **h) {
     65   if ((h == NULL) || (*h != NULL)) {
     66     return MARISA_ALPHA_HANDLE_ERROR;
     67   }
     68   *h = new (std::nothrow) marisa_alpha_trie_();
     69   return (*h != NULL) ? MARISA_ALPHA_OK : MARISA_ALPHA_MEMORY_ERROR;
     70 }
     71 
     72 marisa_alpha_status marisa_alpha_end(marisa_alpha_trie *h) {
     73   if (h == NULL) {
     74     return MARISA_ALPHA_HANDLE_ERROR;
     75   }
     76   delete h;
     77   return MARISA_ALPHA_OK;
     78 }
     79 
     80 marisa_alpha_status marisa_alpha_build(marisa_alpha_trie *h,
     81     const char * const *keys, size_t num_keys, const size_t *key_lengths,
     82     const double *key_weights, marisa_alpha_uint32 *key_ids, int flags) try {
     83   if (h == NULL) {
     84     return MARISA_ALPHA_HANDLE_ERROR;
     85   }
     86   h->trie.build(keys, num_keys, key_lengths, key_weights, key_ids, flags);
     87   h->mapper.clear();
     88   return MARISA_ALPHA_OK;
     89 } catch (const marisa_alpha::Exception &ex) {
     90   return ex.status();
     91 }
     92 
     93 marisa_alpha_status marisa_alpha_mmap(marisa_alpha_trie *h,
     94     const char *filename, long offset, int whence) try {
     95   if (h == NULL) {
     96     return MARISA_ALPHA_HANDLE_ERROR;
     97   }
     98   h->trie.mmap(&h->mapper, filename, offset, whence);
     99   return MARISA_ALPHA_OK;
    100 } catch (const marisa_alpha::Exception &ex) {
    101   return ex.status();
    102 }
    103 
    104 marisa_alpha_status marisa_alpha_map(marisa_alpha_trie *h, const void *ptr,
    105     size_t size) try {
    106   if (h == NULL) {
    107     return MARISA_ALPHA_HANDLE_ERROR;
    108   }
    109   h->trie.map(ptr, size);
    110   h->mapper.clear();
    111   return MARISA_ALPHA_OK;
    112 } catch (const marisa_alpha::Exception &ex) {
    113   return ex.status();
    114 }
    115 
    116 marisa_alpha_status marisa_alpha_load(marisa_alpha_trie *h,
    117     const char *filename, long offset, int whence) try {
    118   if (h == NULL) {
    119     return MARISA_ALPHA_HANDLE_ERROR;
    120   }
    121   h->trie.load(filename, offset, whence);
    122   h->mapper.clear();
    123   return MARISA_ALPHA_OK;
    124 } catch (const marisa_alpha::Exception &ex) {
    125   return ex.status();
    126 }
    127 
    128 marisa_alpha_status marisa_alpha_fread(marisa_alpha_trie *h, FILE *file) try {
    129   if (h == NULL) {
    130     return MARISA_ALPHA_HANDLE_ERROR;
    131   }
    132   h->trie.fread(file);
    133   h->mapper.clear();
    134   return MARISA_ALPHA_OK;
    135 } catch (const marisa_alpha::Exception &ex) {
    136   return ex.status();
    137 }
    138 
    139 marisa_alpha_status marisa_alpha_read(marisa_alpha_trie *h, int fd) try {
    140   if (h == NULL) {
    141     return MARISA_ALPHA_HANDLE_ERROR;
    142   }
    143   h->trie.read(fd);
    144   h->mapper.clear();
    145   return MARISA_ALPHA_OK;
    146 } catch (const marisa_alpha::Exception &ex) {
    147   return ex.status();
    148 }
    149 
    150 marisa_alpha_status marisa_alpha_save(const marisa_alpha_trie *h,
    151     const char *filename, int trunc_flag, long offset, int whence) try {
    152   if (h == NULL) {
    153     return MARISA_ALPHA_HANDLE_ERROR;
    154   }
    155   h->trie.save(filename, trunc_flag != 0, offset, whence);
    156   return MARISA_ALPHA_OK;
    157 } catch (const marisa_alpha::Exception &ex) {
    158   return ex.status();
    159 }
    160 
    161 marisa_alpha_status marisa_alpha_fwrite(const marisa_alpha_trie *h,
    162     FILE *file) try {
    163   if (h == NULL) {
    164     return MARISA_ALPHA_HANDLE_ERROR;
    165   }
    166   h->trie.fwrite(file);
    167   return MARISA_ALPHA_OK;
    168 } catch (const marisa_alpha::Exception &ex) {
    169   return ex.status();
    170 }
    171 
    172 marisa_alpha_status marisa_alpha_write(const marisa_alpha_trie *h, int fd) try {
    173   if (h == NULL) {
    174     return MARISA_ALPHA_HANDLE_ERROR;
    175   }
    176   h->trie.write(fd);
    177   return MARISA_ALPHA_OK;
    178 } catch (const marisa_alpha::Exception &ex) {
    179   return ex.status();
    180 }
    181 
    182 marisa_alpha_status marisa_alpha_restore(const marisa_alpha_trie *h,
    183     marisa_alpha_uint32 key_id, char *key_buf, size_t key_buf_size,
    184     size_t *key_length) try {
    185   if (h == NULL) {
    186     return MARISA_ALPHA_HANDLE_ERROR;
    187   } else if (key_length == NULL) {
    188     return MARISA_ALPHA_PARAM_ERROR;
    189   }
    190   *key_length = h->trie.restore(key_id, key_buf, key_buf_size);
    191   return MARISA_ALPHA_OK;
    192 } catch (const marisa_alpha::Exception &ex) {
    193   return ex.status();
    194 }
    195 
    196 marisa_alpha_status marisa_alpha_lookup(const marisa_alpha_trie *h,
    197     const char *ptr, size_t length, marisa_alpha_uint32 *key_id) try {
    198   if (h == NULL) {
    199     return MARISA_ALPHA_HANDLE_ERROR;
    200   } else if (key_id == NULL) {
    201     return MARISA_ALPHA_PARAM_ERROR;
    202   }
    203   if (length == MARISA_ALPHA_ZERO_TERMINATED) {
    204     *key_id = h->trie.lookup(ptr);
    205   } else {
    206     *key_id = h->trie.lookup(ptr, length);
    207   }
    208   return MARISA_ALPHA_OK;
    209 } catch (const marisa_alpha::Exception &ex) {
    210   return ex.status();
    211 }
    212 
    213 marisa_alpha_status marisa_alpha_find(const marisa_alpha_trie *h,
    214     const char *ptr, size_t length,
    215     marisa_alpha_uint32 *key_ids, size_t *key_lengths,
    216     size_t max_num_results, size_t *num_results) try {
    217   if (h == NULL) {
    218     return MARISA_ALPHA_HANDLE_ERROR;
    219   } else if (num_results == NULL) {
    220     return MARISA_ALPHA_PARAM_ERROR;
    221   }
    222   if (length == MARISA_ALPHA_ZERO_TERMINATED) {
    223     *num_results = h->trie.find(ptr, key_ids, key_lengths, max_num_results);
    224   } else {
    225     *num_results = h->trie.find(ptr, length,
    226         key_ids, key_lengths, max_num_results);
    227   }
    228   return MARISA_ALPHA_OK;
    229 } catch (const marisa_alpha::Exception &ex) {
    230   return ex.status();
    231 }
    232 
    233 marisa_alpha_status marisa_alpha_find_first(const marisa_alpha_trie *h,
    234     const char *ptr, size_t length,
    235     marisa_alpha_uint32 *key_id, size_t *key_length) {
    236   if (h == NULL) {
    237     return MARISA_ALPHA_HANDLE_ERROR;
    238   } else if (key_id == NULL) {
    239     return MARISA_ALPHA_PARAM_ERROR;
    240   }
    241   if (length == MARISA_ALPHA_ZERO_TERMINATED) {
    242     *key_id = h->trie.find_first(ptr, key_length);
    243   } else {
    244     *key_id = h->trie.find_first(ptr, length, key_length);
    245   }
    246   return MARISA_ALPHA_OK;
    247 }
    248 
    249 marisa_alpha_status marisa_alpha_find_last(const marisa_alpha_trie *h,
    250     const char *ptr, size_t length,
    251     marisa_alpha_uint32 *key_id, size_t *key_length) {
    252   if (h == NULL) {
    253     return MARISA_ALPHA_HANDLE_ERROR;
    254   } else if (key_id == NULL) {
    255     return MARISA_ALPHA_PARAM_ERROR;
    256   }
    257   if (length == MARISA_ALPHA_ZERO_TERMINATED) {
    258     *key_id = h->trie.find_last(ptr, key_length);
    259   } else {
    260     *key_id = h->trie.find_last(ptr, length, key_length);
    261   }
    262   return MARISA_ALPHA_OK;
    263 }
    264 
    265 marisa_alpha_status marisa_alpha_find_callback(const marisa_alpha_trie *h,
    266     const char *ptr, size_t length,
    267     int (*callback)(void *, marisa_alpha_uint32, size_t),
    268     void *first_arg_to_callback) try {
    269   if (h == NULL) {
    270     return MARISA_ALPHA_HANDLE_ERROR;
    271   } else if (callback == NULL) {
    272     return MARISA_ALPHA_PARAM_ERROR;
    273   }
    274   if (length == MARISA_ALPHA_ZERO_TERMINATED) {
    275     h->trie.find_callback(ptr,
    276         ::FindCallback(callback, first_arg_to_callback));
    277   } else {
    278     h->trie.find_callback(ptr, length,
    279         ::FindCallback(callback, first_arg_to_callback));
    280   }
    281   return MARISA_ALPHA_OK;
    282 } catch (const marisa_alpha::Exception &ex) {
    283   return ex.status();
    284 }
    285 
    286 marisa_alpha_status marisa_alpha_predict(const marisa_alpha_trie *h,
    287     const char *ptr, size_t length, marisa_alpha_uint32 *key_ids,
    288     size_t max_num_results, size_t *num_results) {
    289   return marisa_alpha_predict_breadth_first(h, ptr, length,
    290       key_ids, max_num_results, num_results);
    291 }
    292 
    293 marisa_alpha_status marisa_alpha_predict_breadth_first(
    294     const marisa_alpha_trie *h, const char *ptr, size_t length,
    295     marisa_alpha_uint32 *key_ids, size_t max_num_results,
    296     size_t *num_results) try {
    297   if (h == NULL) {
    298     return MARISA_ALPHA_HANDLE_ERROR;
    299   } else if (num_results == NULL) {
    300     return MARISA_ALPHA_PARAM_ERROR;
    301   }
    302   if (length == MARISA_ALPHA_ZERO_TERMINATED) {
    303     *num_results = h->trie.predict_breadth_first(
    304         ptr, key_ids, NULL, max_num_results);
    305   } else {
    306     *num_results = h->trie.predict_breadth_first(
    307         ptr, length, key_ids, NULL, max_num_results);
    308   }
    309   return MARISA_ALPHA_OK;
    310 } catch (const marisa_alpha::Exception &ex) {
    311   return ex.status();
    312 }
    313 
    314 marisa_alpha_status marisa_alpha_predict_depth_first(
    315     const marisa_alpha_trie *h, const char *ptr, size_t length,
    316     marisa_alpha_uint32 *key_ids, size_t max_num_results,
    317     size_t *num_results) try {
    318   if (h == NULL) {
    319     return MARISA_ALPHA_HANDLE_ERROR;
    320   } else if (num_results == NULL) {
    321     return MARISA_ALPHA_PARAM_ERROR;
    322   }
    323   if (length == MARISA_ALPHA_ZERO_TERMINATED) {
    324     *num_results = h->trie.predict_depth_first(
    325         ptr, key_ids, NULL, max_num_results);
    326   } else {
    327     *num_results = h->trie.predict_depth_first(
    328         ptr, length, key_ids, NULL, max_num_results);
    329   }
    330   return MARISA_ALPHA_OK;
    331 } catch (const marisa_alpha::Exception &ex) {
    332   return ex.status();
    333 }
    334 
    335 marisa_alpha_status marisa_alpha_predict_callback(const marisa_alpha_trie *h,
    336     const char *ptr, size_t length,
    337     int (*callback)(void *, marisa_alpha_uint32, const char *, size_t),
    338     void *first_arg_to_callback) try {
    339   if (h == NULL) {
    340     return MARISA_ALPHA_HANDLE_ERROR;
    341   } else if (callback == NULL) {
    342     return MARISA_ALPHA_PARAM_ERROR;
    343   }
    344   if (length == MARISA_ALPHA_ZERO_TERMINATED) {
    345     h->trie.predict_callback(ptr,
    346         ::PredictCallback(callback, first_arg_to_callback));
    347   } else {
    348     h->trie.predict_callback(ptr, length,
    349         ::PredictCallback(callback, first_arg_to_callback));
    350   }
    351   return MARISA_ALPHA_OK;
    352 } catch (const marisa_alpha::Exception &ex) {
    353   return ex.status();
    354 }
    355 
    356 size_t marisa_alpha_get_num_tries(const marisa_alpha_trie *h) {
    357   return (h != NULL) ? h->trie.num_tries() : 0;
    358 }
    359 
    360 size_t marisa_alpha_get_num_keys(const marisa_alpha_trie *h) {
    361   return (h != NULL) ? h->trie.num_keys() : 0;
    362 }
    363 
    364 size_t marisa_alpha_get_num_nodes(const marisa_alpha_trie *h) {
    365   return (h != NULL) ? h->trie.num_nodes() : 0;
    366 }
    367 
    368 size_t marisa_alpha_get_total_size(const marisa_alpha_trie *h) {
    369   return (h != NULL) ? h->trie.total_size() : 0;
    370 }
    371 
    372 marisa_alpha_status marisa_alpha_clear(marisa_alpha_trie *h) {
    373   if (h == NULL) {
    374     return MARISA_ALPHA_HANDLE_ERROR;
    375   }
    376   h->trie.clear();
    377   h->mapper.clear();
    378   return MARISA_ALPHA_OK;
    379 }
    380 
    381 }  // extern "C"
    382