1 #include <algorithm> 2 #include <functional> 3 #include <queue> 4 #include <stdexcept> 5 6 #include "range.h" 7 #include "trie.h" 8 9 namespace marisa_alpha { 10 11 void Trie::build(const char * const *keys, std::size_t num_keys, 12 const std::size_t *key_lengths, const double *key_weights, 13 UInt32 *key_ids, int flags) { 14 MARISA_ALPHA_THROW_IF((keys == NULL) && (num_keys != 0), 15 MARISA_ALPHA_PARAM_ERROR); 16 Vector<Key<String> > temp_keys; 17 temp_keys.resize(num_keys); 18 for (std::size_t i = 0; i < temp_keys.size(); ++i) { 19 MARISA_ALPHA_THROW_IF(keys[i] == NULL, MARISA_ALPHA_PARAM_ERROR); 20 std::size_t length = 0; 21 if (key_lengths == NULL) { 22 while (keys[i][length] != '\0') { 23 ++length; 24 } 25 } else { 26 length = key_lengths[i]; 27 } 28 MARISA_ALPHA_THROW_IF(length > MARISA_ALPHA_MAX_LENGTH, 29 MARISA_ALPHA_SIZE_ERROR); 30 temp_keys[i].set_str(String(keys[i], length)); 31 temp_keys[i].set_weight((key_weights != NULL) ? key_weights[i] : 1.0); 32 } 33 build_trie(temp_keys, key_ids, flags); 34 } 35 36 void Trie::build(const std::vector<std::string> &keys, 37 std::vector<UInt32> *key_ids, int flags) { 38 Vector<Key<String> > temp_keys; 39 temp_keys.resize(keys.size()); 40 for (std::size_t i = 0; i < temp_keys.size(); ++i) { 41 MARISA_ALPHA_THROW_IF(keys[i].length() > MARISA_ALPHA_MAX_LENGTH, 42 MARISA_ALPHA_SIZE_ERROR); 43 temp_keys[i].set_str(String(keys[i].c_str(), keys[i].length())); 44 temp_keys[i].set_weight(1.0); 45 } 46 build_trie(temp_keys, key_ids, flags); 47 } 48 49 void Trie::build(const std::vector<std::pair<std::string, double> > &keys, 50 std::vector<UInt32> *key_ids, int flags) { 51 Vector<Key<String> > temp_keys; 52 temp_keys.resize(keys.size()); 53 for (std::size_t i = 0; i < temp_keys.size(); ++i) { 54 MARISA_ALPHA_THROW_IF(keys[i].first.length() > MARISA_ALPHA_MAX_LENGTH, 55 MARISA_ALPHA_SIZE_ERROR); 56 temp_keys[i].set_str(String( 57 keys[i].first.c_str(), keys[i].first.length())); 58 temp_keys[i].set_weight(keys[i].second); 59 } 60 build_trie(temp_keys, key_ids, flags); 61 } 62 63 void Trie::build_trie(Vector<Key<String> > &keys, 64 std::vector<UInt32> *key_ids, int flags) { 65 if (key_ids == NULL) { 66 build_trie(keys, static_cast<UInt32 *>(NULL), flags); 67 return; 68 } 69 try { 70 std::vector<UInt32> temp_key_ids(keys.size()); 71 build_trie(keys, temp_key_ids.empty() ? NULL : &temp_key_ids[0], flags); 72 key_ids->swap(temp_key_ids); 73 } catch (const std::bad_alloc &) { 74 MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR); 75 } catch (const std::length_error &) { 76 MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR); 77 } 78 } 79 80 void Trie::build_trie(Vector<Key<String> > &keys, 81 UInt32 *key_ids, int flags) { 82 Trie temp; 83 Vector<UInt32> terminals; 84 Progress progress(flags); 85 MARISA_ALPHA_THROW_IF(!progress.is_valid(), MARISA_ALPHA_PARAM_ERROR); 86 temp.build_trie(keys, &terminals, progress); 87 88 typedef std::pair<UInt32, UInt32> TerminalIdPair; 89 Vector<TerminalIdPair> pairs; 90 pairs.resize(terminals.size()); 91 for (UInt32 i = 0; i < pairs.size(); ++i) { 92 pairs[i].first = terminals[i]; 93 pairs[i].second = i; 94 } 95 terminals.clear(); 96 std::sort(pairs.begin(), pairs.end()); 97 98 UInt32 node = 0; 99 for (UInt32 i = 0; i < pairs.size(); ++i) { 100 while (node < pairs[i].first) { 101 temp.terminal_flags_.push_back(false); 102 ++node; 103 } 104 if (node == pairs[i].first) { 105 temp.terminal_flags_.push_back(true); 106 ++node; 107 } 108 } 109 while (node < temp.labels_.size()) { 110 temp.terminal_flags_.push_back(false); 111 ++node; 112 } 113 terminal_flags_.push_back(false); 114 temp.terminal_flags_.build(); 115 temp.terminal_flags_.clear_select0s(); 116 progress.test_total_size(temp.terminal_flags_.total_size()); 117 118 if (key_ids != NULL) { 119 for (UInt32 i = 0; i < pairs.size(); ++i) { 120 key_ids[pairs[i].second] = temp.node_to_key_id(pairs[i].first); 121 } 122 } 123 MARISA_ALPHA_THROW_IF(progress.total_size() != temp.total_size(), 124 MARISA_ALPHA_UNEXPECTED_ERROR); 125 temp.swap(this); 126 } 127 128 template <typename T> 129 void Trie::build_trie(Vector<Key<T> > &keys, 130 Vector<UInt32> *terminals, Progress &progress) { 131 build_cur(keys, terminals, progress); 132 progress.test_total_size(louds_.total_size()); 133 progress.test_total_size(sizeof(num_first_branches_)); 134 progress.test_total_size(sizeof(num_keys_)); 135 if (link_flags_.empty()) { 136 labels_.shrink(); 137 progress.test_total_size(labels_.total_size()); 138 progress.test_total_size(link_flags_.total_size()); 139 progress.test_total_size(links_.total_size()); 140 progress.test_total_size(tail_.total_size()); 141 return; 142 } 143 144 Vector<UInt32> next_terminals; 145 build_next(keys, &next_terminals, progress); 146 147 if (has_trie()) { 148 progress.test_total_size(trie_->terminal_flags_.total_size()); 149 } else if (tail_.mode() == MARISA_ALPHA_BINARY_TAIL) { 150 labels_.push_back('\0'); 151 link_flags_.push_back(true); 152 } 153 link_flags_.build(); 154 155 for (UInt32 i = 0; i < next_terminals.size(); ++i) { 156 labels_[link_flags_.select1(i)] = (UInt8)(next_terminals[i] % 256); 157 next_terminals[i] /= 256; 158 } 159 link_flags_.clear_select0s(); 160 if (has_trie() || (tail_.mode() == MARISA_ALPHA_TEXT_TAIL)) { 161 link_flags_.clear_select1s(); 162 } 163 164 links_.build(next_terminals); 165 labels_.shrink(); 166 progress.test_total_size(labels_.total_size()); 167 progress.test_total_size(link_flags_.total_size()); 168 progress.test_total_size(links_.total_size()); 169 progress.test_total_size(tail_.total_size()); 170 } 171 172 template <typename T> 173 void Trie::build_cur(Vector<Key<T> > &keys, 174 Vector<UInt32> *terminals, Progress &progress) try { 175 num_keys_ = sort_keys(keys); 176 louds_.push_back(true); 177 louds_.push_back(false); 178 labels_.push_back('\0'); 179 link_flags_.push_back(false); 180 181 Vector<Key<T> > rest_keys; 182 std::queue<Range> queue; 183 Vector<WRange> wranges; 184 queue.push(Range(0, (UInt32)keys.size(), 0)); 185 while (!queue.empty()) { 186 const UInt32 node = (UInt32)(link_flags_.size() - queue.size()); 187 Range range = queue.front(); 188 queue.pop(); 189 190 while ((range.begin() < range.end()) && 191 (keys[range.begin()].str().length() == range.pos())) { 192 keys[range.begin()].set_terminal(node); 193 range.set_begin(range.begin() + 1); 194 } 195 if (range.begin() == range.end()) { 196 louds_.push_back(false); 197 continue; 198 } 199 200 wranges.clear(); 201 double weight = keys[range.begin()].weight(); 202 for (UInt32 i = range.begin() + 1; i < range.end(); ++i) { 203 if (keys[i - 1].str()[range.pos()] != keys[i].str()[range.pos()]) { 204 wranges.push_back(WRange(range.begin(), i, range.pos(), weight)); 205 range.set_begin(i); 206 weight = 0.0; 207 } 208 weight += keys[i].weight(); 209 } 210 wranges.push_back(WRange(range, weight)); 211 if (progress.order() == MARISA_ALPHA_WEIGHT_ORDER) { 212 std::stable_sort(wranges.begin(), wranges.end(), std::greater<WRange>()); 213 } 214 if (node == 0) { 215 num_first_branches_ = wranges.size(); 216 } 217 for (UInt32 i = 0; i < wranges.size(); ++i) { 218 const WRange &wrange = wranges[i]; 219 UInt32 pos = wrange.pos() + 1; 220 if ((progress.tail() != MARISA_ALPHA_WITHOUT_TAIL) || 221 !progress.is_last()) { 222 while (pos < keys[wrange.begin()].str().length()) { 223 UInt32 j; 224 for (j = wrange.begin() + 1; j < wrange.end(); ++j) { 225 if (keys[j - 1].str()[pos] != keys[j].str()[pos]) { 226 break; 227 } 228 } 229 if (j < wrange.end()) { 230 break; 231 } 232 ++pos; 233 } 234 } 235 if ((progress.trie() != MARISA_ALPHA_PATRICIA_TRIE) && 236 (pos != keys[wrange.end() - 1].str().length())) { 237 pos = wrange.pos() + 1; 238 } 239 louds_.push_back(true); 240 if (pos == wrange.pos() + 1) { 241 labels_.push_back(keys[wrange.begin()].str()[wrange.pos()]); 242 link_flags_.push_back(false); 243 } else { 244 labels_.push_back('\0'); 245 link_flags_.push_back(true); 246 Key<T> rest_key; 247 rest_key.set_str(keys[wrange.begin()].str().substr( 248 wrange.pos(), pos - wrange.pos())); 249 rest_key.set_weight(wrange.weight()); 250 rest_keys.push_back(rest_key); 251 } 252 wranges[i].set_pos(pos); 253 queue.push(wranges[i].range()); 254 } 255 louds_.push_back(false); 256 } 257 louds_.push_back(false); 258 louds_.build(); 259 if (progress.trie_id() != 0) { 260 louds_.clear_select0s(); 261 } 262 if (rest_keys.empty()) { 263 link_flags_.clear(); 264 } 265 266 build_terminals(keys, terminals); 267 keys.swap(&rest_keys); 268 } catch (const std::bad_alloc &) { 269 MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR); 270 } catch (const std::length_error &) { 271 MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR); 272 } 273 274 void Trie::build_next(Vector<Key<String> > &keys, 275 Vector<UInt32> *terminals, Progress &progress) { 276 if (progress.is_last()) { 277 Vector<String> strs; 278 strs.resize(keys.size()); 279 for (UInt32 i = 0; i < strs.size(); ++i) { 280 strs[i] = keys[i].str(); 281 } 282 tail_.build(strs, terminals, progress.tail()); 283 return; 284 } 285 Vector<Key<RString> > rkeys; 286 rkeys.resize(keys.size()); 287 for (UInt32 i = 0; i < rkeys.size(); ++i) { 288 rkeys[i].set_str(RString(keys[i].str())); 289 rkeys[i].set_weight(keys[i].weight()); 290 } 291 keys.clear(); 292 trie_.reset(new (std::nothrow) Trie); 293 MARISA_ALPHA_THROW_IF(!has_trie(), MARISA_ALPHA_MEMORY_ERROR); 294 trie_->build_trie(rkeys, terminals, ++progress); 295 } 296 297 void Trie::build_next(Vector<Key<RString> > &rkeys, 298 Vector<UInt32> *terminals, Progress &progress) { 299 if (progress.is_last()) { 300 Vector<String> strs; 301 strs.resize(rkeys.size()); 302 for (UInt32 i = 0; i < strs.size(); ++i) { 303 strs[i] = String(rkeys[i].str().ptr(), rkeys[i].str().length()); 304 } 305 tail_.build(strs, terminals, progress.tail()); 306 return; 307 } 308 trie_.reset(new (std::nothrow) Trie); 309 MARISA_ALPHA_THROW_IF(!has_trie(), MARISA_ALPHA_MEMORY_ERROR); 310 trie_->build_trie(rkeys, terminals, ++progress); 311 } 312 313 template <typename T> 314 UInt32 Trie::sort_keys(Vector<Key<T> > &keys) const { 315 if (keys.empty()) { 316 return 0; 317 } 318 for (UInt32 i = 0; i < keys.size(); ++i) { 319 keys[i].set_id(i); 320 } 321 std::sort(keys.begin(), keys.end()); 322 UInt32 count = 1; 323 for (UInt32 i = 1; i < keys.size(); ++i) { 324 if (keys[i - 1].str() != keys[i].str()) { 325 ++count; 326 } 327 } 328 return count; 329 } 330 331 template <typename T> 332 void Trie::build_terminals(const Vector<Key<T> > &keys, 333 Vector<UInt32> *terminals) const { 334 Vector<UInt32> temp_terminals; 335 temp_terminals.resize(keys.size()); 336 for (UInt32 i = 0; i < keys.size(); ++i) { 337 temp_terminals[keys[i].id()] = keys[i].terminal(); 338 } 339 temp_terminals.swap(terminals); 340 } 341 342 } // namespace marisa_alpha 343