
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: sorenj (at) google.com (Jeffrey Sorensen)
//
#ifndef FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
#define FST_EXTENSIONS_NGRAM_NGRAM_FST_H_

#include <stddef.h>
#include <string.h>
#include <algorithm>
#include <fstream>
#include <iostream>
#include <queue>
#include <string>
#include <utility>
#include <vector>
using std::vector;

#include <fst/compat.h>
#include <fst/fstlib.h>
#include <fst/mapped-file.h>
#include <fst/extensions/ngram/bitmap-index.h>

// NgramFst implements an n-gram language model based upon the LOUDS data
// structure.  Please refer to "Unary Data Structures for Language Models"
// http://research.google.com/pubs/archive/37218.pdf

namespace fst {
template <class A> class NGramFst;
template <class A> class NGramFstMatcher;

// Instance data containing mutable state for bookkeeping repeated access to
// the same state.
42 template <class A> 43 struct NGramFstInst { 44 typedef typename A::Label Label; 45 typedef typename A::StateId StateId; 46 typedef typename A::Weight Weight; 47 StateId state_; 48 size_t num_futures_; 49 size_t offset_; 50 size_t node_; 51 StateId node_state_; 52 vector<Label> context_; 53 StateId context_state_; 54 NGramFstInst() 55 : state_(kNoStateId), node_state_(kNoStateId), 56 context_state_(kNoStateId) { } 57 }; 58 59 // Implementation class for LOUDS based NgramFst interface 60 template <class A> 61 class NGramFstImpl : public FstImpl<A> { 62 using FstImpl<A>::SetInputSymbols; 63 using FstImpl<A>::SetOutputSymbols; 64 using FstImpl<A>::SetType; 65 using FstImpl<A>::WriteHeader; 66 67 friend class ArcIterator<NGramFst<A> >; 68 friend class NGramFstMatcher<A>; 69 70 public: 71 using FstImpl<A>::InputSymbols; 72 using FstImpl<A>::SetProperties; 73 using FstImpl<A>::Properties; 74 75 typedef A Arc; 76 typedef typename A::Label Label; 77 typedef typename A::StateId StateId; 78 typedef typename A::Weight Weight; 79 80 NGramFstImpl() : data_region_(0), data_(0), owned_(false) { 81 SetType("ngram"); 82 SetInputSymbols(NULL); 83 SetOutputSymbols(NULL); 84 SetProperties(kStaticProperties); 85 } 86 87 NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out); 88 89 ~NGramFstImpl() { 90 if (owned_) { 91 delete [] data_; 92 } 93 delete data_region_; 94 } 95 96 static NGramFstImpl<A>* Read(istream &strm, // NOLINT 97 const FstReadOptions &opts) { 98 NGramFstImpl<A>* impl = new NGramFstImpl(); 99 FstHeader hdr; 100 if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) return 0; 101 uint64 num_states, num_futures, num_final; 102 const size_t offset = sizeof(num_states) + sizeof(num_futures) + 103 sizeof(num_final); 104 // Peek at num_states and num_futures to see how much more needs to be read. 
105 strm.read(reinterpret_cast<char *>(&num_states), sizeof(num_states)); 106 strm.read(reinterpret_cast<char *>(&num_futures), sizeof(num_futures)); 107 strm.read(reinterpret_cast<char *>(&num_final), sizeof(num_final)); 108 size_t size = Storage(num_states, num_futures, num_final); 109 MappedFile *data_region = MappedFile::Allocate(size); 110 char *data = reinterpret_cast<char *>(data_region->mutable_data()); 111 // Copy num_states, num_futures and num_final back into data. 112 memcpy(data, reinterpret_cast<char *>(&num_states), sizeof(num_states)); 113 memcpy(data + sizeof(num_states), reinterpret_cast<char *>(&num_futures), 114 sizeof(num_futures)); 115 memcpy(data + sizeof(num_states) + sizeof(num_futures), 116 reinterpret_cast<char *>(&num_final), sizeof(num_final)); 117 strm.read(data + offset, size - offset); 118 if (!strm) { 119 delete impl; 120 return NULL; 121 } 122 impl->Init(data, false, data_region); 123 return impl; 124 } 125 126 bool Write(ostream &strm, // NOLINT 127 const FstWriteOptions &opts) const { 128 FstHeader hdr; 129 hdr.SetStart(Start()); 130 hdr.SetNumStates(num_states_); 131 WriteHeader(strm, opts, kFileVersion, &hdr); 132 strm.write(data_, StorageSize()); 133 return strm; 134 } 135 136 StateId Start() const { 137 return 1; 138 } 139 140 Weight Final(StateId state) const { 141 if (final_index_.Get(state)) { 142 return final_probs_[final_index_.Rank1(state)]; 143 } else { 144 return Weight::Zero(); 145 } 146 } 147 148 size_t NumArcs(StateId state, NGramFstInst<A> *inst = NULL) const { 149 if (inst == NULL) { 150 const size_t next_zero = future_index_.Select0(state + 1); 151 const size_t this_zero = future_index_.Select0(state); 152 return next_zero - this_zero - 1; 153 } 154 SetInstFuture(state, inst); 155 return inst->num_futures_ + ((state == 0) ? 0 : 1); 156 } 157 158 size_t NumInputEpsilons(StateId state) const { 159 // State 0 has no parent, thus no backoff. 
160 if (state == 0) return 0; 161 return 1; 162 } 163 164 size_t NumOutputEpsilons(StateId state) const { 165 return NumInputEpsilons(state); 166 } 167 168 StateId NumStates() const { 169 return num_states_; 170 } 171 172 void InitStateIterator(StateIteratorData<A>* data) const { 173 data->base = 0; 174 data->nstates = num_states_; 175 } 176 177 static size_t Storage(uint64 num_states, uint64 num_futures, 178 uint64 num_final) { 179 uint64 b64; 180 Weight weight; 181 Label label; 182 size_t offset = sizeof(num_states) + sizeof(num_futures) + 183 sizeof(num_final); 184 offset += sizeof(b64) * ( 185 BitmapIndex::StorageSize(num_states * 2 + 1) + 186 BitmapIndex::StorageSize(num_futures + num_states + 1) + 187 BitmapIndex::StorageSize(num_states)); 188 offset += (num_states + 1) * sizeof(label) + num_futures * sizeof(label); 189 // Pad for alignemnt, see 190 // http://en.wikipedia.org/wiki/Data_structure_alignment#Computing_padding 191 offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1); 192 offset += (num_states + 1) * sizeof(weight) + num_final * sizeof(weight) + 193 (num_futures + 1) * sizeof(weight); 194 return offset; 195 } 196 197 void SetInstFuture(StateId state, NGramFstInst<A> *inst) const { 198 if (inst->state_ != state) { 199 inst->state_ = state; 200 const size_t next_zero = future_index_.Select0(state + 1); 201 const size_t this_zero = future_index_.Select0(state); 202 inst->num_futures_ = next_zero - this_zero - 1; 203 inst->offset_ = future_index_.Rank1(future_index_.Select0(state) + 1); 204 } 205 } 206 207 void SetInstNode(NGramFstInst<A> *inst) const { 208 if (inst->node_state_ != inst->state_) { 209 inst->node_state_ = inst->state_; 210 inst->node_ = context_index_.Select1(inst->state_); 211 } 212 } 213 214 void SetInstContext(NGramFstInst<A> *inst) const { 215 SetInstNode(inst); 216 if (inst->context_state_ != inst->state_) { 217 inst->context_state_ = inst->state_; 218 inst->context_.clear(); 219 size_t node = inst->node_; 220 while 
(node != 0) { 221 inst->context_.push_back(context_words_[context_index_.Rank1(node)]); 222 node = context_index_.Select1(context_index_.Rank0(node) - 1); 223 } 224 } 225 } 226 227 // Access to the underlying representation 228 const char* GetData(size_t* data_size) const { 229 *data_size = StorageSize(); 230 return data_; 231 } 232 233 void Init(const char* data, bool owned, MappedFile *file = 0); 234 235 const vector<Label> &GetContext(StateId s, NGramFstInst<A> *inst) const { 236 SetInstFuture(s, inst); 237 SetInstContext(inst); 238 return inst->context_; 239 } 240 241 size_t StorageSize() const { 242 return Storage(num_states_, num_futures_, num_final_); 243 } 244 245 void GetStates(const vector<Label>& context, vector<StateId> *states) const; 246 247 private: 248 StateId Transition(const vector<Label> &context, Label future) const; 249 250 // Properties always true for this Fst class. 251 static const uint64 kStaticProperties = kAcceptor | kIDeterministic | 252 kODeterministic | kEpsilons | kIEpsilons | kOEpsilons | kILabelSorted | 253 kOLabelSorted | kWeighted | kCyclic | kInitialAcyclic | kNotTopSorted | 254 kAccessible | kCoAccessible | kNotString | kExpanded; 255 // Current file format version. 256 static const int kFileVersion = 4; 257 // Minimum file format version supported. 
258 static const int kMinFileVersion = 4; 259 260 MappedFile *data_region_; 261 const char* data_; 262 bool owned_; // True if we own data_ 263 uint64 num_states_, num_futures_, num_final_; 264 size_t root_num_children_; 265 const Label *root_children_; 266 size_t root_first_child_; 267 // borrowed references 268 const uint64 *context_, *future_, *final_; 269 const Label *context_words_, *future_words_; 270 const Weight *backoff_, *final_probs_, *future_probs_; 271 BitmapIndex context_index_; 272 BitmapIndex future_index_; 273 BitmapIndex final_index_; 274 275 void operator=(const NGramFstImpl<A> &); // Disallow 276 }; 277 278 template<typename A> 279 NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out) 280 : data_region_(0), data_(0), owned_(false) { 281 typedef A Arc; 282 typedef typename Arc::Label Label; 283 typedef typename Arc::Weight Weight; 284 typedef typename Arc::StateId StateId; 285 SetType("ngram"); 286 SetInputSymbols(fst.InputSymbols()); 287 SetOutputSymbols(fst.OutputSymbols()); 288 SetProperties(kStaticProperties); 289 290 // Check basic requirements for an OpenGRM language model Fst. 291 int64 props = kAcceptor | kIDeterministic | kIEpsilons | kILabelSorted; 292 if (fst.Properties(props, true) != props) { 293 FSTERROR() << "NGramFst only accepts OpenGRM langauge models as input"; 294 SetProperties(kError, kError); 295 return; 296 } 297 298 int64 num_states = CountStates(fst); 299 Label* context = new Label[num_states]; 300 301 // Find the unigram state by starting from the start state, following 302 // epsilons. 
303 StateId unigram = fst.Start(); 304 while (1) { 305 if (unigram == kNoStateId) { 306 FSTERROR() << "Could not identify unigram state."; 307 SetProperties(kError, kError); 308 return; 309 } 310 ArcIterator<Fst<A> > aiter(fst, unigram); 311 if (aiter.Done()) { 312 LOG(WARNING) << "Unigram state " << unigram << " has no arcs."; 313 break; 314 } 315 if (aiter.Value().ilabel != 0) break; 316 unigram = aiter.Value().nextstate; 317 } 318 319 // Each state's context is determined by the subtree it is under from the 320 // unigram state. 321 queue<pair<StateId, Label> > label_queue; 322 vector<bool> visited(num_states); 323 // Force an epsilon link to the start state. 324 label_queue.push(make_pair(fst.Start(), 0)); 325 for (ArcIterator<Fst<A> > aiter(fst, unigram); 326 !aiter.Done(); aiter.Next()) { 327 label_queue.push(make_pair(aiter.Value().nextstate, aiter.Value().ilabel)); 328 } 329 // investigate states in breadth first fashion to assign context words. 330 while (!label_queue.empty()) { 331 pair<StateId, Label> &now = label_queue.front(); 332 if (!visited[now.first]) { 333 context[now.first] = now.second; 334 visited[now.first] = true; 335 for (ArcIterator<Fst<A> > aiter(fst, now.first); 336 !aiter.Done(); aiter.Next()) { 337 const Arc &arc = aiter.Value(); 338 if (arc.ilabel != 0) { 339 label_queue.push(make_pair(arc.nextstate, now.second)); 340 } 341 } 342 } 343 label_queue.pop(); 344 } 345 visited.clear(); 346 347 // The arc from the start state should be assigned an epsilon to put it 348 // in front of the all other labels (which makes Start state 1 after 349 // unigram which is state 0). 350 context[fst.Start()] = 0; 351 352 // Build the tree of contexts fst by reversing the epsilon arcs from fst. 
353 VectorFst<Arc> context_fst; 354 uint64 num_final = 0; 355 for (int i = 0; i < num_states; ++i) { 356 if (fst.Final(i) != Weight::Zero()) { 357 ++num_final; 358 } 359 context_fst.SetFinal(context_fst.AddState(), fst.Final(i)); 360 } 361 context_fst.SetStart(unigram); 362 context_fst.SetInputSymbols(fst.InputSymbols()); 363 context_fst.SetOutputSymbols(fst.OutputSymbols()); 364 int64 num_context_arcs = 0; 365 int64 num_futures = 0; 366 for (StateIterator<Fst<A> > siter(fst); !siter.Done(); siter.Next()) { 367 const StateId &state = siter.Value(); 368 num_futures += fst.NumArcs(state) - fst.NumInputEpsilons(state); 369 ArcIterator<Fst<A> > aiter(fst, state); 370 if (!aiter.Done()) { 371 const Arc &arc = aiter.Value(); 372 // this arc goes from state to arc.nextstate, so create an arc from 373 // arc.nextstate to state to reverse it. 374 if (arc.ilabel == 0) { 375 context_fst.AddArc(arc.nextstate, Arc(context[state], context[state], 376 arc.weight, state)); 377 num_context_arcs++; 378 } 379 } 380 } 381 if (num_context_arcs != context_fst.NumStates() - 1) { 382 FSTERROR() << "Number of contexts arcs != number of states - 1"; 383 SetProperties(kError, kError); 384 return; 385 } 386 if (context_fst.NumStates() != num_states) { 387 FSTERROR() << "Number of contexts != number of states"; 388 SetProperties(kError, kError); 389 return; 390 } 391 int64 context_props = context_fst.Properties(kIDeterministic | 392 kILabelSorted, true); 393 if (!(context_props & kIDeterministic)) { 394 FSTERROR() << "Input fst is not structured properly"; 395 SetProperties(kError, kError); 396 return; 397 } 398 if (!(context_props & kILabelSorted)) { 399 ArcSort(&context_fst, ILabelCompare<Arc>()); 400 } 401 402 delete [] context; 403 404 uint64 b64; 405 Weight weight; 406 Label label = kNoLabel; 407 const size_t storage = Storage(num_states, num_futures, num_final); 408 MappedFile *data_region = MappedFile::Allocate(storage); 409 char *data = reinterpret_cast<char 
*>(data_region->mutable_data()); 410 memset(data, 0, storage); 411 size_t offset = 0; 412 memcpy(data + offset, reinterpret_cast<char *>(&num_states), 413 sizeof(num_states)); 414 offset += sizeof(num_states); 415 memcpy(data + offset, reinterpret_cast<char *>(&num_futures), 416 sizeof(num_futures)); 417 offset += sizeof(num_futures); 418 memcpy(data + offset, reinterpret_cast<char *>(&num_final), 419 sizeof(num_final)); 420 offset += sizeof(num_final); 421 uint64* context_bits = reinterpret_cast<uint64*>(data + offset); 422 offset += BitmapIndex::StorageSize(num_states * 2 + 1) * sizeof(b64); 423 uint64* future_bits = reinterpret_cast<uint64*>(data + offset); 424 offset += 425 BitmapIndex::StorageSize(num_futures + num_states + 1) * sizeof(b64); 426 uint64* final_bits = reinterpret_cast<uint64*>(data + offset); 427 offset += BitmapIndex::StorageSize(num_states) * sizeof(b64); 428 Label* context_words = reinterpret_cast<Label*>(data + offset); 429 offset += (num_states + 1) * sizeof(label); 430 Label* future_words = reinterpret_cast<Label*>(data + offset); 431 offset += num_futures * sizeof(label); 432 offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1); 433 Weight* backoff = reinterpret_cast<Weight*>(data + offset); 434 offset += (num_states + 1) * sizeof(weight); 435 Weight* final_probs = reinterpret_cast<Weight*>(data + offset); 436 offset += num_final * sizeof(weight); 437 Weight* future_probs = reinterpret_cast<Weight*>(data + offset); 438 int64 context_arc = 0, future_arc = 0, context_bit = 0, future_bit = 0, 439 final_bit = 0; 440 441 // pseudo-root bits 442 BitmapIndex::Set(context_bits, context_bit++); 443 ++context_bit; 444 context_words[context_arc] = label; 445 backoff[context_arc] = Weight::Zero(); 446 context_arc++; 447 448 ++future_bit; 449 if (order_out) { 450 order_out->clear(); 451 order_out->resize(num_states); 452 } 453 454 queue<StateId> context_q; 455 context_q.push(context_fst.Start()); 456 StateId state_number = 0; 457 while 
(!context_q.empty()) { 458 const StateId &state = context_q.front(); 459 if (order_out) { 460 (*order_out)[state] = state_number; 461 } 462 463 const Weight &final = context_fst.Final(state); 464 if (final != Weight::Zero()) { 465 BitmapIndex::Set(final_bits, state_number); 466 final_probs[final_bit] = final; 467 ++final_bit; 468 } 469 470 for (ArcIterator<VectorFst<A> > aiter(context_fst, state); 471 !aiter.Done(); aiter.Next()) { 472 const Arc &arc = aiter.Value(); 473 context_words[context_arc] = arc.ilabel; 474 backoff[context_arc] = arc.weight; 475 ++context_arc; 476 BitmapIndex::Set(context_bits, context_bit++); 477 context_q.push(arc.nextstate); 478 } 479 ++context_bit; 480 481 for (ArcIterator<Fst<A> > aiter(fst, state); !aiter.Done(); aiter.Next()) { 482 const Arc &arc = aiter.Value(); 483 if (arc.ilabel != 0) { 484 future_words[future_arc] = arc.ilabel; 485 future_probs[future_arc] = arc.weight; 486 ++future_arc; 487 BitmapIndex::Set(future_bits, future_bit++); 488 } 489 } 490 ++future_bit; 491 ++state_number; 492 context_q.pop(); 493 } 494 495 if ((state_number != num_states) || 496 (context_bit != num_states * 2 + 1) || 497 (context_arc != num_states) || 498 (future_arc != num_futures) || 499 (future_bit != num_futures + num_states + 1) || 500 (final_bit != num_final)) { 501 FSTERROR() << "Structure problems detected during construction"; 502 SetProperties(kError, kError); 503 return; 504 } 505 506 Init(data, false, data_region); 507 } 508 509 template<typename A> 510 inline void NGramFstImpl<A>::Init(const char* data, bool owned, 511 MappedFile *data_region) { 512 if (owned_) { 513 delete [] data_; 514 } 515 delete data_region_; 516 data_region_ = data_region; 517 owned_ = owned; 518 data_ = data; 519 size_t offset = 0; 520 num_states_ = *(reinterpret_cast<const uint64*>(data_ + offset)); 521 offset += sizeof(num_states_); 522 num_futures_ = *(reinterpret_cast<const uint64*>(data_ + offset)); 523 offset += sizeof(num_futures_); 524 num_final_ = 
*(reinterpret_cast<const uint64*>(data_ + offset)); 525 offset += sizeof(num_final_); 526 uint64 bits; 527 size_t context_bits = num_states_ * 2 + 1; 528 size_t future_bits = num_futures_ + num_states_ + 1; 529 context_ = reinterpret_cast<const uint64*>(data_ + offset); 530 offset += BitmapIndex::StorageSize(context_bits) * sizeof(bits); 531 future_ = reinterpret_cast<const uint64*>(data_ + offset); 532 offset += BitmapIndex::StorageSize(future_bits) * sizeof(bits); 533 final_ = reinterpret_cast<const uint64*>(data_ + offset); 534 offset += BitmapIndex::StorageSize(num_states_) * sizeof(bits); 535 context_words_ = reinterpret_cast<const Label*>(data_ + offset); 536 offset += (num_states_ + 1) * sizeof(*context_words_); 537 future_words_ = reinterpret_cast<const Label*>(data_ + offset); 538 offset += num_futures_ * sizeof(*future_words_); 539 offset = (offset + sizeof(*backoff_) - 1) & ~(sizeof(*backoff_) - 1); 540 backoff_ = reinterpret_cast<const Weight*>(data_ + offset); 541 offset += (num_states_ + 1) * sizeof(*backoff_); 542 final_probs_ = reinterpret_cast<const Weight*>(data_ + offset); 543 offset += num_final_ * sizeof(*final_probs_); 544 future_probs_ = reinterpret_cast<const Weight*>(data_ + offset); 545 546 context_index_.BuildIndex(context_, context_bits); 547 future_index_.BuildIndex(future_, future_bits); 548 final_index_.BuildIndex(final_, num_states_); 549 550 const size_t node_rank = context_index_.Rank1(0); 551 root_first_child_ = context_index_.Select0(node_rank) + 1; 552 if (context_index_.Get(root_first_child_) == false) { 553 FSTERROR() << "Missing unigrams"; 554 SetProperties(kError, kError); 555 return; 556 } 557 const size_t last_child = context_index_.Select0(node_rank + 1) - 1; 558 root_num_children_ = last_child - root_first_child_ + 1; 559 root_children_ = context_words_ + context_index_.Rank1(root_first_child_); 560 } 561 562 template<typename A> 563 inline typename A::StateId NGramFstImpl<A>::Transition( 564 const vector<Label> 
&context, Label future) const { 565 const Label *children = root_children_; 566 const Label *loc = lower_bound(children, children + root_num_children_, 567 future); 568 if (loc == children + root_num_children_ || *loc != future) { 569 return context_index_.Rank1(0); 570 } 571 size_t node = root_first_child_ + loc - children; 572 size_t node_rank = context_index_.Rank1(node); 573 size_t first_child = context_index_.Select0(node_rank) + 1; 574 if (context_index_.Get(first_child) == false) { 575 return context_index_.Rank1(node); 576 } 577 size_t last_child = context_index_.Select0(node_rank + 1) - 1; 578 for (int word = context.size() - 1; word >= 0; --word) { 579 children = context_words_ + context_index_.Rank1(first_child); 580 loc = lower_bound(children, children + last_child - first_child + 1, 581 context[word]); 582 if (loc == children + last_child - first_child + 1 || 583 *loc != context[word]) { 584 break; 585 } 586 node = first_child + loc - children; 587 node_rank = context_index_.Rank1(node); 588 first_child = context_index_.Select0(node_rank) + 1; 589 if (context_index_.Get(first_child) == false) break; 590 last_child = context_index_.Select0(node_rank + 1) - 1; 591 } 592 return context_index_.Rank1(node); 593 } 594 595 template<typename A> 596 inline void NGramFstImpl<A>::GetStates( 597 const vector<Label> &context, 598 vector<typename A::StateId>* states) const { 599 states->clear(); 600 states->push_back(0); 601 typename vector<Label>::const_reverse_iterator cit = context.rbegin(); 602 const Label *children = root_children_; 603 const Label *loc = lower_bound(children, children + root_num_children_, *cit); 604 if (loc == children + root_num_children_ || *loc != *cit) return; 605 size_t node = root_first_child_ + loc - children; 606 states->push_back(context_index_.Rank1(node)); 607 if (context.size() == 1) return; 608 size_t node_rank = context_index_.Rank1(node); 609 size_t first_child = context_index_.Select0(node_rank) + 1; 610 ++cit; 611 if 
(context_index_.Get(first_child) != false) { 612 size_t last_child = context_index_.Select0(node_rank + 1) - 1; 613 while (cit != context.rend()) { 614 children = context_words_ + context_index_.Rank1(first_child); 615 loc = lower_bound(children, children + last_child - first_child + 1, 616 *cit); 617 if (loc == children + last_child - first_child + 1 || *loc != *cit) { 618 break; 619 } 620 ++cit; 621 node = first_child + loc - children; 622 states->push_back(context_index_.Rank1(node)); 623 node_rank = context_index_.Rank1(node); 624 first_child = context_index_.Select0(node_rank) + 1; 625 if (context_index_.Get(first_child) == false) break; 626 last_child = context_index_.Select0(node_rank + 1) - 1; 627 } 628 } 629 } 630 631 /*****************************************************************************/ 632 template<class A> 633 class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > { 634 friend class ArcIterator<NGramFst<A> >; 635 friend class NGramFstMatcher<A>; 636 637 public: 638 typedef A Arc; 639 typedef typename A::StateId StateId; 640 typedef typename A::Label Label; 641 typedef typename A::Weight Weight; 642 typedef NGramFstImpl<A> Impl; 643 644 explicit NGramFst(const Fst<A> &dst) 645 : ImplToExpandedFst<Impl>(new Impl(dst, NULL)) {} 646 647 NGramFst(const Fst<A> &fst, vector<StateId>* order_out) 648 : ImplToExpandedFst<Impl>(new Impl(fst, order_out)) {} 649 650 // Because the NGramFstImpl is a const stateless data structure, there 651 // is never a need to do anything beside copy the reference. 652 NGramFst(const NGramFst<A> &fst, bool safe = false) 653 : ImplToExpandedFst<Impl>(fst, false) {} 654 655 NGramFst() : ImplToExpandedFst<Impl>(new Impl()) {} 656 657 // Non-standard constructor to initialize NGramFst directly from data. 658 NGramFst(const char* data, bool owned) : ImplToExpandedFst<Impl>(new Impl()) { 659 GetImpl()->Init(data, owned, NULL); 660 } 661 662 // Get method that gets the data associated with Init(). 
663 const char* GetData(size_t* data_size) const { 664 return GetImpl()->GetData(data_size); 665 } 666 667 const vector<Label> GetContext(StateId s) const { 668 return GetImpl()->GetContext(s, &inst_); 669 } 670 671 // Consumes as much as possible of context from right to left, returns the 672 // the states corresponding to the increasingly conditioned input sequence. 673 void GetStates(const vector<Label>& context, vector<StateId> *state) const { 674 return GetImpl()->GetStates(context, state); 675 } 676 677 virtual size_t NumArcs(StateId s) const { 678 return GetImpl()->NumArcs(s, &inst_); 679 } 680 681 virtual NGramFst<A>* Copy(bool safe = false) const { 682 return new NGramFst(*this, safe); 683 } 684 685 static NGramFst<A>* Read(istream &strm, const FstReadOptions &opts) { 686 Impl* impl = Impl::Read(strm, opts); 687 return impl ? new NGramFst<A>(impl) : 0; 688 } 689 690 static NGramFst<A>* Read(const string &filename) { 691 if (!filename.empty()) { 692 ifstream strm(filename.c_str(), ifstream::in | ifstream::binary); 693 if (!strm) { 694 LOG(ERROR) << "NGramFst::Read: Can't open file: " << filename; 695 return 0; 696 } 697 return Read(strm, FstReadOptions(filename)); 698 } else { 699 return Read(cin, FstReadOptions("standard input")); 700 } 701 } 702 703 virtual bool Write(ostream &strm, const FstWriteOptions &opts) const { 704 return GetImpl()->Write(strm, opts); 705 } 706 707 virtual bool Write(const string &filename) const { 708 return Fst<A>::WriteFile(filename); 709 } 710 711 virtual inline void InitStateIterator(StateIteratorData<A>* data) const { 712 GetImpl()->InitStateIterator(data); 713 } 714 715 virtual inline void InitArcIterator( 716 StateId s, ArcIteratorData<A>* data) const; 717 718 virtual MatcherBase<A>* InitMatcher(MatchType match_type) const { 719 return new NGramFstMatcher<A>(*this, match_type); 720 } 721 722 size_t StorageSize() const { 723 return GetImpl()->StorageSize(); 724 } 725 726 private: 727 explicit NGramFst(Impl* impl) : 
ImplToExpandedFst<Impl>(impl) {} 728 729 Impl* GetImpl() const { 730 return 731 ImplToExpandedFst<Impl, ExpandedFst<A> >::GetImpl(); 732 } 733 734 void SetImpl(Impl* impl, bool own_impl = true) { 735 ImplToExpandedFst<Impl, Fst<A> >::SetImpl(impl, own_impl); 736 } 737 738 mutable NGramFstInst<A> inst_; 739 }; 740 741 template <class A> inline void 742 NGramFst<A>::InitArcIterator(StateId s, ArcIteratorData<A>* data) const { 743 GetImpl()->SetInstFuture(s, &inst_); 744 GetImpl()->SetInstNode(&inst_); 745 data->base = new ArcIterator<NGramFst<A> >(*this, s); 746 } 747 748 /*****************************************************************************/ 749 template <class A> 750 class NGramFstMatcher : public MatcherBase<A> { 751 public: 752 typedef A Arc; 753 typedef typename A::Label Label; 754 typedef typename A::StateId StateId; 755 typedef typename A::Weight Weight; 756 757 NGramFstMatcher(const NGramFst<A> &fst, MatchType match_type) 758 : fst_(fst), inst_(fst.inst_), match_type_(match_type), 759 current_loop_(false), 760 loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) { 761 if (match_type_ == MATCH_OUTPUT) { 762 swap(loop_.ilabel, loop_.olabel); 763 } 764 } 765 766 NGramFstMatcher(const NGramFstMatcher<A> &matcher, bool safe = false) 767 : fst_(matcher.fst_), inst_(matcher.inst_), 768 match_type_(matcher.match_type_), current_loop_(false), 769 loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) { 770 if (match_type_ == MATCH_OUTPUT) { 771 swap(loop_.ilabel, loop_.olabel); 772 } 773 } 774 775 virtual NGramFstMatcher<A>* Copy(bool safe = false) const { 776 return new NGramFstMatcher<A>(*this, safe); 777 } 778 779 virtual MatchType Type(bool test) const { 780 return match_type_; 781 } 782 783 virtual const Fst<A> &GetFst() const { 784 return fst_; 785 } 786 787 virtual uint64 Properties(uint64 props) const { 788 return props; 789 } 790 791 private: 792 virtual void SetState_(StateId s) { 793 fst_.GetImpl()->SetInstFuture(s, &inst_); 794 current_loop_ = false; 795 } 
796 797 virtual bool Find_(Label label) { 798 const Label nolabel = kNoLabel; 799 done_ = true; 800 if (label == 0 || label == nolabel) { 801 if (label == 0) { 802 current_loop_ = true; 803 loop_.nextstate = inst_.state_; 804 } 805 // The unigram state has no epsilon arc. 806 if (inst_.state_ != 0) { 807 arc_.ilabel = arc_.olabel = 0; 808 fst_.GetImpl()->SetInstNode(&inst_); 809 arc_.nextstate = fst_.GetImpl()->context_index_.Rank1( 810 fst_.GetImpl()->context_index_.Select1( 811 fst_.GetImpl()->context_index_.Rank0(inst_.node_) - 1)); 812 arc_.weight = fst_.GetImpl()->backoff_[inst_.state_]; 813 done_ = false; 814 } 815 } else { 816 const Label *start = fst_.GetImpl()->future_words_ + inst_.offset_; 817 const Label *end = start + inst_.num_futures_; 818 const Label* search = lower_bound(start, end, label); 819 if (search != end && *search == label) { 820 size_t state = search - start; 821 arc_.ilabel = arc_.olabel = label; 822 arc_.weight = fst_.GetImpl()->future_probs_[inst_.offset_ + state]; 823 fst_.GetImpl()->SetInstContext(&inst_); 824 arc_.nextstate = fst_.GetImpl()->Transition(inst_.context_, label); 825 done_ = false; 826 } 827 } 828 return !Done_(); 829 } 830 831 virtual bool Done_() const { 832 return !current_loop_ && done_; 833 } 834 835 virtual const Arc& Value_() const { 836 return (current_loop_) ? 
loop_ : arc_; 837 } 838 839 virtual void Next_() { 840 if (current_loop_) { 841 current_loop_ = false; 842 } else { 843 done_ = true; 844 } 845 } 846 847 const NGramFst<A>& fst_; 848 NGramFstInst<A> inst_; 849 MatchType match_type_; // Supplied by caller 850 bool done_; 851 Arc arc_; 852 bool current_loop_; // Current arc is the implicit loop 853 Arc loop_; 854 }; 855 856 /*****************************************************************************/ 857 template<class A> 858 class ArcIterator<NGramFst<A> > : public ArcIteratorBase<A> { 859 public: 860 typedef A Arc; 861 typedef typename A::Label Label; 862 typedef typename A::StateId StateId; 863 typedef typename A::Weight Weight; 864 865 ArcIterator(const NGramFst<A> &fst, StateId state) 866 : lazy_(~0), impl_(fst.GetImpl()), i_(0), flags_(kArcValueFlags) { 867 inst_ = fst.inst_; 868 impl_->SetInstFuture(state, &inst_); 869 impl_->SetInstNode(&inst_); 870 } 871 872 bool Done() const { 873 return i_ >= ((inst_.node_ == 0) ? inst_.num_futures_ : 874 inst_.num_futures_ + 1); 875 } 876 877 const Arc &Value() const { 878 bool eps = (inst_.node_ != 0 && i_ == 0); 879 StateId state = (inst_.node_ == 0) ? i_ : i_ - 1; 880 if (flags_ & lazy_ & (kArcILabelValue | kArcOLabelValue)) { 881 arc_.ilabel = 882 arc_.olabel = eps ? 0 : impl_->future_words_[inst_.offset_ + state]; 883 lazy_ &= ~(kArcILabelValue | kArcOLabelValue); 884 } 885 if (flags_ & lazy_ & kArcNextStateValue) { 886 if (eps) { 887 arc_.nextstate = impl_->context_index_.Rank1( 888 impl_->context_index_.Select1( 889 impl_->context_index_.Rank0(inst_.node_) - 1)); 890 } else { 891 if (lazy_ & kArcNextStateValue) { 892 impl_->SetInstContext(&inst_); // first time only. 893 } 894 arc_.nextstate = 895 impl_->Transition(inst_.context_, 896 impl_->future_words_[inst_.offset_ + state]); 897 } 898 lazy_ &= ~kArcNextStateValue; 899 } 900 if (flags_ & lazy_ & kArcWeightValue) { 901 arc_.weight = eps ? 
impl_->backoff_[inst_.state_] : 902 impl_->future_probs_[inst_.offset_ + state]; 903 lazy_ &= ~kArcWeightValue; 904 } 905 return arc_; 906 } 907 908 void Next() { 909 ++i_; 910 lazy_ = ~0; 911 } 912 913 size_t Position() const { return i_; } 914 915 void Reset() { 916 i_ = 0; 917 lazy_ = ~0; 918 } 919 920 void Seek(size_t a) { 921 if (i_ != a) { 922 i_ = a; 923 lazy_ = ~0; 924 } 925 } 926 927 uint32 Flags() const { 928 return flags_; 929 } 930 931 void SetFlags(uint32 f, uint32 m) { 932 flags_ &= ~m; 933 flags_ |= (f & kArcValueFlags); 934 } 935 936 private: 937 virtual bool Done_() const { return Done(); } 938 virtual const Arc& Value_() const { return Value(); } 939 virtual void Next_() { Next(); } 940 virtual size_t Position_() const { return Position(); } 941 virtual void Reset_() { Reset(); } 942 virtual void Seek_(size_t a) { Seek(a); } 943 uint32 Flags_() const { return Flags(); } 944 void SetFlags_(uint32 f, uint32 m) { SetFlags(f, m); } 945 946 mutable Arc arc_; 947 mutable uint32 lazy_; 948 const NGramFstImpl<A> *impl_; 949 mutable NGramFstInst<A> inst_; 950 951 size_t i_; 952 uint32 flags_; 953 954 DISALLOW_COPY_AND_ASSIGN(ArcIterator); 955 }; 956 957 /*****************************************************************************/ 958 // Specialization for NGramFst; see generic version in fst.h 959 // for sample usage (but use the ProdLmFst type!). This version 960 // should inline. 
961 template <class A> 962 class StateIterator<NGramFst<A> > : public StateIteratorBase<A> { 963 public: 964 typedef typename A::StateId StateId; 965 966 explicit StateIterator(const NGramFst<A> &fst) 967 : s_(0), num_states_(fst.NumStates()) { } 968 969 bool Done() const { return s_ >= num_states_; } 970 StateId Value() const { return s_; } 971 void Next() { ++s_; } 972 void Reset() { s_ = 0; } 973 974 private: 975 virtual bool Done_() const { return Done(); } 976 virtual StateId Value_() const { return Value(); } 977 virtual void Next_() { Next(); } 978 virtual void Reset_() { Reset(); } 979 980 StateId s_, num_states_; 981 982 DISALLOW_COPY_AND_ASSIGN(StateIterator); 983 }; 984 } // namespace fst 985 #endif // FST_EXTENSIONS_NGRAM_NGRAM_FST_H_ 986