1 #include "trie.h" 2 3 extern "C" { 4 5 namespace { 6 7 class FindCallback { 8 public: 9 typedef int (*Func)(void *, marisa_uint32, size_t); 10 11 FindCallback(Func func, void *first_arg) 12 : func_(func), first_arg_(first_arg) {} 13 FindCallback(const FindCallback &callback) 14 : func_(callback.func_), first_arg_(callback.first_arg_) {} 15 16 bool operator()(marisa::UInt32 key_id, std::size_t key_length) const { 17 return func_(first_arg_, key_id, key_length) != 0; 18 } 19 20 private: 21 Func func_; 22 void *first_arg_; 23 24 // Disallows assignment. 25 FindCallback &operator=(const FindCallback &); 26 }; 27 28 class PredictCallback { 29 public: 30 typedef int (*Func)(void *, marisa_uint32, const char *, size_t); 31 32 PredictCallback(Func func, void *first_arg) 33 : func_(func), first_arg_(first_arg) {} 34 PredictCallback(const PredictCallback &callback) 35 : func_(callback.func_), first_arg_(callback.first_arg_) {} 36 37 bool operator()(marisa::UInt32 key_id, const std::string &key) const { 38 return func_(first_arg_, key_id, key.c_str(), key.length()) != 0; 39 } 40 41 private: 42 Func func_; 43 void *first_arg_; 44 45 // Disallows assignment. 46 PredictCallback &operator=(const PredictCallback &); 47 }; 48 49 } // namespace 50 51 struct marisa_trie_ { 52 public: 53 marisa_trie_() : trie(), mapper() {} 54 55 marisa::Trie trie; 56 marisa::Mapper mapper; 57 58 private: 59 // Disallows copy and assignment. 60 marisa_trie_(const marisa_trie_ &); 61 marisa_trie_ &operator=(const marisa_trie_ &); 62 }; 63 64 marisa_status marisa_init(marisa_trie **h) { 65 if ((h == NULL) || (*h != NULL)) { 66 return MARISA_HANDLE_ERROR; 67 } 68 *h = new (std::nothrow) marisa_trie_(); 69 return (*h != NULL) ? MARISA_OK : MARISA_MEMORY_ERROR; 70 } 71 72 marisa_status marisa_end(marisa_trie *h) { 73 if (h == NULL) { 74 return MARISA_HANDLE_ERROR; 75 } 76 delete h; 77 return MARISA_OK; 78 } 79 80 marisa_status marisa_build(marisa_trie *h, const char * const *keys, 81 size_t num_keys, const size_t *key_lengths, const double *key_weights, 82 marisa_uint32 *key_ids, int flags) { 83 if (h == NULL) { 84 return MARISA_HANDLE_ERROR; 85 } 86 h->trie.build(keys, num_keys, key_lengths, key_weights, key_ids, flags); 87 h->mapper.clear(); 88 return MARISA_OK; 89 } 90 91 marisa_status marisa_mmap(marisa_trie *h, const char *filename, 92 long offset, int whence) { 93 if (h == NULL) { 94 return MARISA_HANDLE_ERROR; 95 } 96 h->trie.mmap(&h->mapper, filename, offset, whence); 97 return MARISA_OK; 98 } 99 100 marisa_status marisa_map(marisa_trie *h, const void *ptr, size_t size) { 101 if (h == NULL) { 102 return MARISA_HANDLE_ERROR; 103 } 104 h->trie.map(ptr, size); 105 h->mapper.clear(); 106 return MARISA_OK; 107 } 108 109 marisa_status marisa_load(marisa_trie *h, const char *filename, 110 long offset, int whence) { 111 if (h == NULL) { 112 return MARISA_HANDLE_ERROR; 113 } 114 h->trie.load(filename, offset, whence); 115 h->mapper.clear(); 116 return MARISA_OK; 117 } 118 119 marisa_status marisa_fread(marisa_trie *h, FILE *file) { 120 if (h == NULL) { 121 return MARISA_HANDLE_ERROR; 122 } 123 h->trie.fread(file); 124 h->mapper.clear(); 125 return MARISA_OK; 126 } 127 128 marisa_status marisa_read(marisa_trie *h, int fd) { 129 if (h == NULL) { 130 return MARISA_HANDLE_ERROR; 131 } 132 h->trie.read(fd); 133 h->mapper.clear(); 134 return MARISA_OK; 135 } 136 137 marisa_status marisa_save(const marisa_trie *h, const char *filename, 138 int trunc_flag, long offset, int whence) { 139 if (h == NULL) { 140 return MARISA_HANDLE_ERROR; 141 } 142 h->trie.save(filename, trunc_flag != 0, offset, whence); 143 return MARISA_OK; 144 } 145 146 marisa_status marisa_fwrite(const marisa_trie *h, FILE *file) { 147 if (h == NULL) { 148 return MARISA_HANDLE_ERROR; 149 } 150 h->trie.fwrite(file); 151 return MARISA_OK; 152 } 153 154 marisa_status marisa_write(const marisa_trie *h, int fd) { 155 if (h == NULL) { 156 return MARISA_HANDLE_ERROR; 157 } 158 h->trie.write(fd); 159 return MARISA_OK; 160 } 161 162 marisa_status marisa_restore(const marisa_trie *h, marisa_uint32 key_id, 163 char *key_buf, size_t key_buf_size, size_t *key_length) { 164 if (h == NULL) { 165 return MARISA_HANDLE_ERROR; 166 } else if (key_length == NULL) { 167 return MARISA_PARAM_ERROR; 168 } 169 *key_length = h->trie.restore(key_id, key_buf, key_buf_size); 170 return MARISA_OK; 171 } 172 173 marisa_status marisa_lookup(const marisa_trie *h, 174 const char *ptr, size_t length, marisa_uint32 *key_id) { 175 if (h == NULL) { 176 return MARISA_HANDLE_ERROR; 177 } else if (key_id == NULL) { 178 return MARISA_PARAM_ERROR; 179 } 180 if (length == MARISA_ZERO_TERMINATED) { 181 *key_id = h->trie.lookup(ptr); 182 } else { 183 *key_id = h->trie.lookup(ptr, length); 184 } 185 return MARISA_OK; 186 } 187 188 marisa_status marisa_find(const marisa_trie *h, 189 const char *ptr, size_t length, 190 marisa_uint32 *key_ids, size_t *key_lengths, 191 size_t max_num_results, size_t *num_results) { 192 if (h == NULL) { 193 return MARISA_HANDLE_ERROR; 194 } else if (num_results == NULL) { 195 return MARISA_PARAM_ERROR; 196 } 197 if (length == MARISA_ZERO_TERMINATED) { 198 *num_results = h->trie.find(ptr, key_ids, key_lengths, max_num_results); 199 } else { 200 *num_results = h->trie.find(ptr, length, 201 key_ids, key_lengths, max_num_results); 202 } 203 return MARISA_OK; 204 } 205 206 marisa_status marisa_find_first(const marisa_trie *h, 207 const char *ptr, size_t length, 208 marisa_uint32 *key_id, size_t *key_length) { 209 if (h == NULL) { 210 return MARISA_HANDLE_ERROR; 211 } else if (key_id == NULL) { 212 return MARISA_PARAM_ERROR; 213 } 214 if (length == MARISA_ZERO_TERMINATED) { 215 *key_id = h->trie.find_first(ptr, key_length); 216 } else { 217 *key_id = h->trie.find_first(ptr, length, key_length); 218 } 219 return MARISA_OK; 220 } 221 222 marisa_status marisa_find_last(const marisa_trie *h, 223 const char *ptr, size_t length, 224 marisa_uint32 *key_id, size_t *key_length) { 225 if (h == NULL) { 226 return MARISA_HANDLE_ERROR; 227 } else if (key_id == NULL) { 228 return MARISA_PARAM_ERROR; 229 } 230 if (length == MARISA_ZERO_TERMINATED) { 231 *key_id = h->trie.find_last(ptr, key_length); 232 } else { 233 *key_id = h->trie.find_last(ptr, length, key_length); 234 } 235 return MARISA_OK; 236 } 237 238 marisa_status marisa_find_callback(const marisa_trie *h, 239 const char *ptr, size_t length, 240 int (*callback)(void *, marisa_uint32, size_t), 241 void *first_arg_to_callback) { 242 if (h == NULL) { 243 return MARISA_HANDLE_ERROR; 244 } else if (callback == NULL) { 245 return MARISA_PARAM_ERROR; 246 } 247 if (length == MARISA_ZERO_TERMINATED) { 248 h->trie.find_callback(ptr, 249 ::FindCallback(callback, first_arg_to_callback)); 250 } else { 251 h->trie.find_callback(ptr, length, 252 ::FindCallback(callback, first_arg_to_callback)); 253 } 254 return MARISA_OK; 255 } 256 257 marisa_status marisa_predict(const marisa_trie *h, 258 const char *ptr, size_t length, marisa_uint32 *key_ids, 259 size_t max_num_results, size_t *num_results) { 260 return marisa_predict_breadth_first(h, ptr, length, 261 key_ids, max_num_results, num_results); 262 } 263 264 marisa_status marisa_predict_breadth_first(const marisa_trie *h, 265 const char *ptr, size_t length, marisa_uint32 *key_ids, 266 size_t max_num_results, size_t *num_results) { 267 if (h == NULL) { 268 return MARISA_HANDLE_ERROR; 269 } else if (num_results == NULL) { 270 return MARISA_PARAM_ERROR; 271 } 272 if (length == MARISA_ZERO_TERMINATED) { 273 *num_results = h->trie.predict_breadth_first( 274 ptr, key_ids, NULL, max_num_results); 275 } else { 276 *num_results = h->trie.predict_breadth_first( 277 ptr, length, key_ids, NULL, max_num_results); 278 } 279 return MARISA_OK; 280 } 281 282 marisa_status marisa_predict_depth_first(const marisa_trie *h, 283 const char *ptr, size_t length, marisa_uint32 *key_ids, 284 size_t max_num_results, size_t *num_results) { 285 if (h == NULL) { 286 return MARISA_HANDLE_ERROR; 287 } else if (num_results == NULL) { 288 return MARISA_PARAM_ERROR; 289 } 290 if (length == MARISA_ZERO_TERMINATED) { 291 *num_results = h->trie.predict_depth_first( 292 ptr, key_ids, NULL, max_num_results); 293 } else { 294 *num_results = h->trie.predict_depth_first( 295 ptr, length, key_ids, NULL, max_num_results); 296 } 297 return MARISA_OK; 298 } 299 300 marisa_status marisa_predict_callback(const marisa_trie *h, 301 const char *ptr, size_t length, 302 int (*callback)(void *, marisa_uint32, const char *, size_t), 303 void *first_arg_to_callback) { 304 if (h == NULL) { 305 return MARISA_HANDLE_ERROR; 306 } else if (callback == NULL) { 307 return MARISA_PARAM_ERROR; 308 } 309 if (length == MARISA_ZERO_TERMINATED) { 310 h->trie.predict_callback(ptr, 311 ::PredictCallback(callback, first_arg_to_callback)); 312 } else { 313 h->trie.predict_callback(ptr, length, 314 ::PredictCallback(callback, first_arg_to_callback)); 315 } 316 return MARISA_OK; 317 } 318 319 size_t marisa_get_num_tries(const marisa_trie *h) { 320 return (h != NULL) ? h->trie.num_tries() : 0; 321 } 322 323 size_t marisa_get_num_keys(const marisa_trie *h) { 324 return (h != NULL) ? h->trie.num_keys() : 0; 325 } 326 327 size_t marisa_get_num_nodes(const marisa_trie *h) { 328 return (h != NULL) ? h->trie.num_nodes() : 0; 329 } 330 331 size_t marisa_get_total_size(const marisa_trie *h) { 332 return (h != NULL) ? h->trie.total_size() : 0; 333 } 334 335 marisa_status marisa_clear(marisa_trie *h) { 336 if (h == NULL) { 337 return MARISA_HANDLE_ERROR; 338 } 339 h->trie.clear(); 340 h->mapper.clear(); 341 return MARISA_OK; 342 } 343 344 } // extern "C" 345