Home | History | Annotate | Download | only in ngram
      1 
      2 // Licensed under the Apache License, Version 2.0 (the "License");
      3 // you may not use this file except in compliance with the License.
      4 // You may obtain a copy of the License at
      5 //
      6 //     http://www.apache.org/licenses/LICENSE-2.0
      7 //
      8 // Unless required by applicable law or agreed to in writing, software
      9 // distributed under the License is distributed on an "AS IS" BASIS,
     10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     11 // See the License for the specific language governing permissions and
     12 // limitations under the License.
     13 //
     14 // Copyright 2005-2010 Google, Inc.
     15 // Author: sorenj (at) google.com (Jeffrey Sorensen)
     16 //
     17 #ifndef FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
     18 #define FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
     19 
     20 #include <stddef.h>
     21 #include <string.h>
     22 #include <algorithm>
     23 #include <string>
     24 #include <vector>
     25 using std::vector;
     26 
     27 #include <fst/compat.h>
     28 #include <fst/fstlib.h>
     29 #include <fst/mapped-file.h>
     30 #include <fst/extensions/ngram/bitmap-index.h>
     31 
     32 // NGramFst implements an n-gram language model based upon the LOUDS data
     33 // structure.  Please refer to "Unary Data Structures for Language Models"
     34 // http://research.google.com/pubs/archive/37218.pdf
     35 
     36 namespace fst {
     37 template <class A> class NGramFst;
     38 template <class A> class NGramFstMatcher;
     39 
     40 // Instance data containing mutable state for bookkeeping repeated access to
     41 // the same state.
     42 template <class A>
     43 struct NGramFstInst {
     44   typedef typename A::Label Label;
     45   typedef typename A::StateId StateId;
     46   typedef typename A::Weight Weight;
     47   StateId state_;
     48   size_t num_futures_;
     49   size_t offset_;
     50   size_t node_;
     51   StateId node_state_;
     52   vector<Label> context_;
     53   StateId context_state_;
     54   NGramFstInst()
     55       : state_(kNoStateId), node_state_(kNoStateId),
     56         context_state_(kNoStateId) { }
     57 };
     58 
     59 // Implementation class for LOUDS based NgramFst interface
     60 template <class A>
     61 class NGramFstImpl : public FstImpl<A> {
     62   using FstImpl<A>::SetInputSymbols;
     63   using FstImpl<A>::SetOutputSymbols;
     64   using FstImpl<A>::SetType;
     65   using FstImpl<A>::WriteHeader;
     66 
     67   friend class ArcIterator<NGramFst<A> >;
     68   friend class NGramFstMatcher<A>;
     69 
     70  public:
     71   using FstImpl<A>::InputSymbols;
     72   using FstImpl<A>::SetProperties;
     73   using FstImpl<A>::Properties;
     74 
     75   typedef A Arc;
     76   typedef typename A::Label Label;
     77   typedef typename A::StateId StateId;
     78   typedef typename A::Weight Weight;
     79 
     80   NGramFstImpl() : data_region_(0), data_(0), owned_(false) {
     81     SetType("ngram");
     82     SetInputSymbols(NULL);
     83     SetOutputSymbols(NULL);
     84     SetProperties(kStaticProperties);
     85   }
     86 
     87   NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out);
     88 
     89   ~NGramFstImpl() {
     90     if (owned_) {
     91       delete [] data_;
     92     }
     93     delete data_region_;
     94   }
     95 
     96   static NGramFstImpl<A>* Read(istream &strm,  // NOLINT
     97                                const FstReadOptions &opts) {
     98     NGramFstImpl<A>* impl = new NGramFstImpl();
     99     FstHeader hdr;
    100     if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) return 0;
    101     uint64 num_states, num_futures, num_final;
    102     const size_t offset = sizeof(num_states) + sizeof(num_futures) +
    103         sizeof(num_final);
    104     // Peek at num_states and num_futures to see how much more needs to be read.
    105     strm.read(reinterpret_cast<char *>(&num_states), sizeof(num_states));
    106     strm.read(reinterpret_cast<char *>(&num_futures), sizeof(num_futures));
    107     strm.read(reinterpret_cast<char *>(&num_final), sizeof(num_final));
    108     size_t size = Storage(num_states, num_futures, num_final);
    109     MappedFile *data_region = MappedFile::Allocate(size);
    110     char *data = reinterpret_cast<char *>(data_region->mutable_data());
    111     // Copy num_states, num_futures and num_final back into data.
    112     memcpy(data, reinterpret_cast<char *>(&num_states), sizeof(num_states));
    113     memcpy(data + sizeof(num_states), reinterpret_cast<char *>(&num_futures),
    114            sizeof(num_futures));
    115     memcpy(data + sizeof(num_states) + sizeof(num_futures),
    116            reinterpret_cast<char *>(&num_final), sizeof(num_final));
    117     strm.read(data + offset, size - offset);
    118     if (!strm) {
    119       delete impl;
    120       return NULL;
    121     }
    122     impl->Init(data, false, data_region);
    123     return impl;
    124   }
    125 
    126   bool Write(ostream &strm,   // NOLINT
    127              const FstWriteOptions &opts) const {
    128     FstHeader hdr;
    129     hdr.SetStart(Start());
    130     hdr.SetNumStates(num_states_);
    131     WriteHeader(strm, opts, kFileVersion, &hdr);
    132     strm.write(data_, StorageSize());
    133     return strm;
    134   }
    135 
    136   StateId Start() const {
    137     return 1;
    138   }
    139 
    140   Weight Final(StateId state) const {
    141     if (final_index_.Get(state)) {
    142       return final_probs_[final_index_.Rank1(state)];
    143     } else {
    144       return Weight::Zero();
    145     }
    146   }
    147 
    148   size_t NumArcs(StateId state, NGramFstInst<A> *inst = NULL) const {
    149     if (inst == NULL) {
    150       const size_t next_zero = future_index_.Select0(state + 1);
    151       const size_t this_zero = future_index_.Select0(state);
    152       return next_zero - this_zero - 1;
    153     }
    154     SetInstFuture(state, inst);
    155     return inst->num_futures_ + ((state == 0) ? 0 : 1);
    156   }
    157 
    158   size_t NumInputEpsilons(StateId state) const {
    159     // State 0 has no parent, thus no backoff.
    160     if (state == 0) return 0;
    161     return 1;
    162   }
    163 
    164   size_t NumOutputEpsilons(StateId state) const {
    165     return NumInputEpsilons(state);
    166   }
    167 
    168   StateId NumStates() const {
    169     return num_states_;
    170   }
    171 
    172   void InitStateIterator(StateIteratorData<A>* data) const {
    173     data->base = 0;
    174     data->nstates = num_states_;
    175   }
    176 
    177   static size_t Storage(uint64 num_states, uint64 num_futures,
    178                         uint64 num_final) {
    179     uint64 b64;
    180     Weight weight;
    181     Label label;
    182     size_t offset = sizeof(num_states) + sizeof(num_futures) +
    183         sizeof(num_final);
    184     offset += sizeof(b64) * (
    185         BitmapIndex::StorageSize(num_states * 2 + 1) +
    186         BitmapIndex::StorageSize(num_futures + num_states + 1) +
    187         BitmapIndex::StorageSize(num_states));
    188     offset += (num_states + 1) * sizeof(label) + num_futures * sizeof(label);
    189     // Pad for alignemnt, see
    190     // http://en.wikipedia.org/wiki/Data_structure_alignment#Computing_padding
    191     offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1);
    192     offset += (num_states + 1) * sizeof(weight) + num_final * sizeof(weight) +
    193         (num_futures + 1) * sizeof(weight);
    194     return offset;
    195   }
    196 
    197   void SetInstFuture(StateId state, NGramFstInst<A> *inst) const {
    198     if (inst->state_ != state) {
    199       inst->state_ = state;
    200       const size_t next_zero = future_index_.Select0(state + 1);
    201       const size_t this_zero = future_index_.Select0(state);
    202       inst->num_futures_ = next_zero - this_zero - 1;
    203       inst->offset_ = future_index_.Rank1(future_index_.Select0(state) + 1);
    204     }
    205   }
    206 
    207   void SetInstNode(NGramFstInst<A> *inst) const {
    208     if (inst->node_state_ != inst->state_) {
    209       inst->node_state_ = inst->state_;
    210       inst->node_ = context_index_.Select1(inst->state_);
    211     }
    212   }
    213 
    214   void SetInstContext(NGramFstInst<A> *inst) const {
    215     SetInstNode(inst);
    216     if (inst->context_state_ != inst->state_) {
    217       inst->context_state_ = inst->state_;
    218       inst->context_.clear();
    219       size_t node = inst->node_;
    220       while (node != 0) {
    221         inst->context_.push_back(context_words_[context_index_.Rank1(node)]);
    222         node = context_index_.Select1(context_index_.Rank0(node) - 1);
    223       }
    224     }
    225   }
    226 
    227   // Access to the underlying representation
    228   const char* GetData(size_t* data_size) const {
    229     *data_size = StorageSize();
    230     return data_;
    231   }
    232 
    233   void Init(const char* data, bool owned, MappedFile *file = 0);
    234 
    235   const vector<Label> &GetContext(StateId s, NGramFstInst<A> *inst) const {
    236     SetInstFuture(s, inst);
    237     SetInstContext(inst);
    238     return inst->context_;
    239   }
    240 
    241   size_t StorageSize() const {
    242     return Storage(num_states_, num_futures_, num_final_);
    243   }
    244 
    245   void GetStates(const vector<Label>& context, vector<StateId> *states) const;
    246 
    247  private:
    248   StateId Transition(const vector<Label> &context, Label future) const;
    249 
    250   // Properties always true for this Fst class.
    251   static const uint64 kStaticProperties = kAcceptor | kIDeterministic |
    252       kODeterministic | kEpsilons | kIEpsilons | kOEpsilons | kILabelSorted |
    253       kOLabelSorted | kWeighted | kCyclic | kInitialAcyclic | kNotTopSorted |
    254       kAccessible | kCoAccessible | kNotString | kExpanded;
    255   // Current file format version.
    256   static const int kFileVersion = 4;
    257   // Minimum file format version supported.
    258   static const int kMinFileVersion = 4;
    259 
    260   MappedFile *data_region_;
    261   const char* data_;
    262   bool owned_;  // True if we own data_
    263   uint64 num_states_, num_futures_, num_final_;
    264   size_t root_num_children_;
    265   const Label *root_children_;
    266   size_t root_first_child_;
    267   // borrowed references
    268   const uint64 *context_, *future_, *final_;
    269   const Label *context_words_, *future_words_;
    270   const Weight *backoff_, *final_probs_, *future_probs_;
    271   BitmapIndex context_index_;
    272   BitmapIndex future_index_;
    273   BitmapIndex final_index_;
    274 
    275   void operator=(const NGramFstImpl<A> &);  // Disallow
    276 };
    277 
    278 template<typename A>
    279 NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out)
    280     : data_region_(0), data_(0), owned_(false) {
    281   typedef A Arc;
    282   typedef typename Arc::Label Label;
    283   typedef typename Arc::Weight Weight;
    284   typedef typename Arc::StateId StateId;
    285   SetType("ngram");
    286   SetInputSymbols(fst.InputSymbols());
    287   SetOutputSymbols(fst.OutputSymbols());
    288   SetProperties(kStaticProperties);
    289 
    290   // Check basic requirements for an OpenGRM language model Fst.
    291   int64 props = kAcceptor | kIDeterministic | kIEpsilons | kILabelSorted;
    292   if (fst.Properties(props, true) != props) {
    293     FSTERROR() << "NGramFst only accepts OpenGRM langauge models as input";
    294     SetProperties(kError, kError);
    295     return;
    296   }
    297 
    298   int64 num_states = CountStates(fst);
    299   Label* context = new Label[num_states];
    300 
    301   // Find the unigram state by starting from the start state, following
    302   // epsilons.
    303   StateId unigram = fst.Start();
    304   while (1) {
    305     if (unigram == kNoStateId) {
    306       FSTERROR() << "Could not identify unigram state.";
    307       SetProperties(kError, kError);
    308       return;
    309     }
    310     ArcIterator<Fst<A> > aiter(fst, unigram);
    311     if (aiter.Done()) {
    312       LOG(WARNING) << "Unigram state " << unigram << " has no arcs.";
    313       break;
    314     }
    315     if (aiter.Value().ilabel != 0) break;
    316     unigram = aiter.Value().nextstate;
    317   }
    318 
    319   // Each state's context is determined by the subtree it is under from the
    320   // unigram state.
    321   queue<pair<StateId, Label> > label_queue;
    322   vector<bool> visited(num_states);
    323   // Force an epsilon link to the start state.
    324   label_queue.push(make_pair(fst.Start(), 0));
    325   for (ArcIterator<Fst<A> > aiter(fst, unigram);
    326        !aiter.Done(); aiter.Next()) {
    327     label_queue.push(make_pair(aiter.Value().nextstate, aiter.Value().ilabel));
    328   }
    329   // investigate states in breadth first fashion to assign context words.
    330   while (!label_queue.empty()) {
    331     pair<StateId, Label> &now = label_queue.front();
    332     if (!visited[now.first]) {
    333       context[now.first] = now.second;
    334       visited[now.first] = true;
    335       for (ArcIterator<Fst<A> > aiter(fst, now.first);
    336            !aiter.Done(); aiter.Next()) {
    337         const Arc &arc = aiter.Value();
    338         if (arc.ilabel != 0) {
    339           label_queue.push(make_pair(arc.nextstate, now.second));
    340         }
    341       }
    342     }
    343     label_queue.pop();
    344   }
    345   visited.clear();
    346 
    347   // The arc from the start state should be assigned an epsilon to put it
    348   // in front of the all other labels (which makes Start state 1 after
    349   // unigram which is state 0).
    350   context[fst.Start()] = 0;
    351 
    352   // Build the tree of contexts fst by reversing the epsilon arcs from fst.
    353   VectorFst<Arc> context_fst;
    354   uint64 num_final = 0;
    355   for (int i = 0; i < num_states; ++i) {
    356     if (fst.Final(i) != Weight::Zero()) {
    357       ++num_final;
    358     }
    359     context_fst.SetFinal(context_fst.AddState(), fst.Final(i));
    360   }
    361   context_fst.SetStart(unigram);
    362   context_fst.SetInputSymbols(fst.InputSymbols());
    363   context_fst.SetOutputSymbols(fst.OutputSymbols());
    364   int64 num_context_arcs = 0;
    365   int64 num_futures = 0;
    366   for (StateIterator<Fst<A> > siter(fst); !siter.Done(); siter.Next()) {
    367     const StateId &state = siter.Value();
    368     num_futures += fst.NumArcs(state) - fst.NumInputEpsilons(state);
    369     ArcIterator<Fst<A> > aiter(fst, state);
    370     if (!aiter.Done()) {
    371       const Arc &arc = aiter.Value();
    372       // this arc goes from state to arc.nextstate, so create an arc from
    373       // arc.nextstate to state to reverse it.
    374       if (arc.ilabel == 0) {
    375         context_fst.AddArc(arc.nextstate, Arc(context[state], context[state],
    376                                               arc.weight, state));
    377         num_context_arcs++;
    378       }
    379     }
    380   }
    381   if (num_context_arcs != context_fst.NumStates() - 1) {
    382     FSTERROR() << "Number of contexts arcs != number of states - 1";
    383     SetProperties(kError, kError);
    384     return;
    385   }
    386   if (context_fst.NumStates() != num_states) {
    387     FSTERROR() << "Number of contexts != number of states";
    388     SetProperties(kError, kError);
    389     return;
    390   }
    391   int64 context_props = context_fst.Properties(kIDeterministic |
    392                                                kILabelSorted, true);
    393   if (!(context_props & kIDeterministic)) {
    394     FSTERROR() << "Input fst is not structured properly";
    395     SetProperties(kError, kError);
    396     return;
    397   }
    398   if (!(context_props & kILabelSorted)) {
    399      ArcSort(&context_fst, ILabelCompare<Arc>());
    400   }
    401 
    402   delete [] context;
    403 
    404   uint64 b64;
    405   Weight weight;
    406   Label label = kNoLabel;
    407   const size_t storage = Storage(num_states, num_futures, num_final);
    408   MappedFile *data_region = MappedFile::Allocate(storage);
    409   char *data = reinterpret_cast<char *>(data_region->mutable_data());
    410   memset(data, 0, storage);
    411   size_t offset = 0;
    412   memcpy(data + offset, reinterpret_cast<char *>(&num_states),
    413          sizeof(num_states));
    414   offset += sizeof(num_states);
    415   memcpy(data + offset, reinterpret_cast<char *>(&num_futures),
    416          sizeof(num_futures));
    417   offset += sizeof(num_futures);
    418   memcpy(data + offset, reinterpret_cast<char *>(&num_final),
    419          sizeof(num_final));
    420   offset += sizeof(num_final);
    421   uint64* context_bits = reinterpret_cast<uint64*>(data + offset);
    422   offset += BitmapIndex::StorageSize(num_states * 2 + 1) * sizeof(b64);
    423   uint64* future_bits = reinterpret_cast<uint64*>(data + offset);
    424   offset +=
    425       BitmapIndex::StorageSize(num_futures + num_states + 1) * sizeof(b64);
    426   uint64* final_bits = reinterpret_cast<uint64*>(data + offset);
    427   offset += BitmapIndex::StorageSize(num_states) * sizeof(b64);
    428   Label* context_words = reinterpret_cast<Label*>(data + offset);
    429   offset += (num_states + 1) * sizeof(label);
    430   Label* future_words = reinterpret_cast<Label*>(data + offset);
    431   offset += num_futures * sizeof(label);
    432   offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1);
    433   Weight* backoff = reinterpret_cast<Weight*>(data + offset);
    434   offset += (num_states + 1) * sizeof(weight);
    435   Weight* final_probs = reinterpret_cast<Weight*>(data + offset);
    436   offset += num_final * sizeof(weight);
    437   Weight* future_probs = reinterpret_cast<Weight*>(data + offset);
    438   int64 context_arc = 0, future_arc = 0, context_bit = 0, future_bit = 0,
    439         final_bit = 0;
    440 
    441   // pseudo-root bits
    442   BitmapIndex::Set(context_bits, context_bit++);
    443   ++context_bit;
    444   context_words[context_arc] = label;
    445   backoff[context_arc] = Weight::Zero();
    446   context_arc++;
    447 
    448   ++future_bit;
    449   if (order_out) {
    450     order_out->clear();
    451     order_out->resize(num_states);
    452   }
    453 
    454   queue<StateId> context_q;
    455   context_q.push(context_fst.Start());
    456   StateId state_number = 0;
    457   while (!context_q.empty()) {
    458     const StateId &state = context_q.front();
    459     if (order_out) {
    460       (*order_out)[state] = state_number;
    461     }
    462 
    463     const Weight &final = context_fst.Final(state);
    464     if (final != Weight::Zero()) {
    465       BitmapIndex::Set(final_bits, state_number);
    466       final_probs[final_bit] = final;
    467       ++final_bit;
    468     }
    469 
    470     for (ArcIterator<VectorFst<A> > aiter(context_fst, state);
    471          !aiter.Done(); aiter.Next()) {
    472       const Arc &arc = aiter.Value();
    473       context_words[context_arc] = arc.ilabel;
    474       backoff[context_arc] = arc.weight;
    475       ++context_arc;
    476       BitmapIndex::Set(context_bits, context_bit++);
    477       context_q.push(arc.nextstate);
    478     }
    479     ++context_bit;
    480 
    481     for (ArcIterator<Fst<A> > aiter(fst, state); !aiter.Done(); aiter.Next()) {
    482       const Arc &arc = aiter.Value();
    483       if (arc.ilabel != 0) {
    484         future_words[future_arc] = arc.ilabel;
    485         future_probs[future_arc] = arc.weight;
    486         ++future_arc;
    487         BitmapIndex::Set(future_bits, future_bit++);
    488       }
    489     }
    490     ++future_bit;
    491     ++state_number;
    492     context_q.pop();
    493   }
    494 
    495   if ((state_number !=  num_states) ||
    496       (context_bit != num_states * 2 + 1) ||
    497       (context_arc != num_states) ||
    498       (future_arc != num_futures) ||
    499       (future_bit != num_futures + num_states + 1) ||
    500       (final_bit != num_final)) {
    501     FSTERROR() << "Structure problems detected during construction";
    502     SetProperties(kError, kError);
    503     return;
    504   }
    505 
    506   Init(data, false, data_region);
    507 }
    508 
    509 template<typename A>
    510 inline void NGramFstImpl<A>::Init(const char* data, bool owned,
    511                                   MappedFile *data_region) {
    512   if (owned_) {
    513     delete [] data_;
    514   }
    515   delete data_region_;
    516   data_region_ = data_region;
    517   owned_ = owned;
    518   data_ = data;
    519   size_t offset = 0;
    520   num_states_ = *(reinterpret_cast<const uint64*>(data_ + offset));
    521   offset += sizeof(num_states_);
    522   num_futures_ = *(reinterpret_cast<const uint64*>(data_ + offset));
    523   offset += sizeof(num_futures_);
    524   num_final_ = *(reinterpret_cast<const uint64*>(data_ + offset));
    525   offset += sizeof(num_final_);
    526   uint64 bits;
    527   size_t context_bits = num_states_ * 2 + 1;
    528   size_t future_bits = num_futures_ + num_states_ + 1;
    529   context_ = reinterpret_cast<const uint64*>(data_ + offset);
    530   offset += BitmapIndex::StorageSize(context_bits) * sizeof(bits);
    531   future_ = reinterpret_cast<const uint64*>(data_ + offset);
    532   offset += BitmapIndex::StorageSize(future_bits) * sizeof(bits);
    533   final_ = reinterpret_cast<const uint64*>(data_ + offset);
    534   offset += BitmapIndex::StorageSize(num_states_) * sizeof(bits);
    535   context_words_ = reinterpret_cast<const Label*>(data_ + offset);
    536   offset += (num_states_ + 1) * sizeof(*context_words_);
    537   future_words_ = reinterpret_cast<const Label*>(data_ + offset);
    538   offset += num_futures_ * sizeof(*future_words_);
    539   offset = (offset + sizeof(*backoff_) - 1) & ~(sizeof(*backoff_) - 1);
    540   backoff_ = reinterpret_cast<const Weight*>(data_ + offset);
    541   offset += (num_states_ + 1) * sizeof(*backoff_);
    542   final_probs_ = reinterpret_cast<const Weight*>(data_ + offset);
    543   offset += num_final_ * sizeof(*final_probs_);
    544   future_probs_ = reinterpret_cast<const Weight*>(data_ + offset);
    545 
    546   context_index_.BuildIndex(context_, context_bits);
    547   future_index_.BuildIndex(future_, future_bits);
    548   final_index_.BuildIndex(final_, num_states_);
    549 
    550   const size_t node_rank = context_index_.Rank1(0);
    551   root_first_child_ = context_index_.Select0(node_rank) + 1;
    552   if (context_index_.Get(root_first_child_) == false) {
    553     FSTERROR() << "Missing unigrams";
    554     SetProperties(kError, kError);
    555     return;
    556   }
    557   const size_t last_child = context_index_.Select0(node_rank + 1) - 1;
    558   root_num_children_ = last_child - root_first_child_ + 1;
    559   root_children_ = context_words_ + context_index_.Rank1(root_first_child_);
    560 }
    561 
    562 template<typename A>
    563 inline typename A::StateId NGramFstImpl<A>::Transition(
    564         const vector<Label> &context, Label future) const {
    565   const Label *children = root_children_;
    566   const Label *loc = lower_bound(children, children + root_num_children_,
    567                                  future);
    568   if (loc == children + root_num_children_ || *loc != future) {
    569     return context_index_.Rank1(0);
    570   }
    571   size_t node = root_first_child_ + loc - children;
    572   size_t node_rank = context_index_.Rank1(node);
    573   size_t first_child = context_index_.Select0(node_rank) + 1;
    574   if (context_index_.Get(first_child) == false) {
    575     return context_index_.Rank1(node);
    576   }
    577   size_t last_child = context_index_.Select0(node_rank + 1) - 1;
    578   for (int word = context.size() - 1; word >= 0; --word) {
    579     children = context_words_ + context_index_.Rank1(first_child);
    580     loc = lower_bound(children, children + last_child - first_child + 1,
    581                       context[word]);
    582     if (loc == children + last_child - first_child + 1 ||
    583         *loc != context[word]) {
    584       break;
    585     }
    586     node = first_child + loc - children;
    587     node_rank = context_index_.Rank1(node);
    588     first_child = context_index_.Select0(node_rank) + 1;
    589     if (context_index_.Get(first_child) == false) break;
    590     last_child = context_index_.Select0(node_rank + 1) - 1;
    591   }
    592   return context_index_.Rank1(node);
    593 }
    594 
    595 template<typename A>
    596 inline void NGramFstImpl<A>::GetStates(
    597     const vector<Label> &context,
    598     vector<typename A::StateId>* states) const {
    599   states->clear();
    600   states->push_back(0);
    601   typename vector<Label>::const_reverse_iterator cit = context.rbegin();
    602   const Label *children = root_children_;
    603   const Label *loc = lower_bound(children, children + root_num_children_, *cit);
    604   if (loc == children + root_num_children_ || *loc != *cit) return;
    605   size_t node = root_first_child_ + loc - children;
    606   states->push_back(context_index_.Rank1(node));
    607   if (context.size() == 1) return;
    608   size_t node_rank = context_index_.Rank1(node);
    609   size_t first_child = context_index_.Select0(node_rank) + 1;
    610   ++cit;
    611   if (context_index_.Get(first_child) != false) {
    612     size_t last_child = context_index_.Select0(node_rank + 1) - 1;
    613     while (cit != context.rend()) {
    614       children = context_words_ + context_index_.Rank1(first_child);
    615       loc = lower_bound(children, children + last_child - first_child + 1,
    616                         *cit);
    617       if (loc == children + last_child - first_child + 1 || *loc != *cit) {
    618         break;
    619       }
    620       ++cit;
    621       node = first_child + loc - children;
    622       states->push_back(context_index_.Rank1(node));
    623       node_rank = context_index_.Rank1(node);
    624       first_child = context_index_.Select0(node_rank) + 1;
    625       if (context_index_.Get(first_child) == false) break;
    626       last_child = context_index_.Select0(node_rank + 1) - 1;
    627     }
    628   }
    629 }
    630 
    631 /*****************************************************************************/
    632 template<class A>
    633 class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > {
    634   friend class ArcIterator<NGramFst<A> >;
    635   friend class NGramFstMatcher<A>;
    636 
    637  public:
    638   typedef A Arc;
    639   typedef typename A::StateId StateId;
    640   typedef typename A::Label Label;
    641   typedef typename A::Weight Weight;
    642   typedef NGramFstImpl<A> Impl;
    643 
    644   explicit NGramFst(const Fst<A> &dst)
    645       : ImplToExpandedFst<Impl>(new Impl(dst, NULL)) {}
    646 
    647   NGramFst(const Fst<A> &fst, vector<StateId>* order_out)
    648       : ImplToExpandedFst<Impl>(new Impl(fst, order_out)) {}
    649 
    650   // Because the NGramFstImpl is a const stateless data structure, there
    651   // is never a need to do anything beside copy the reference.
    652   NGramFst(const NGramFst<A> &fst, bool safe = false)
    653       : ImplToExpandedFst<Impl>(fst, false) {}
    654 
    655   NGramFst() : ImplToExpandedFst<Impl>(new Impl()) {}
    656 
    657   // Non-standard constructor to initialize NGramFst directly from data.
    658   NGramFst(const char* data, bool owned) : ImplToExpandedFst<Impl>(new Impl()) {
    659     GetImpl()->Init(data, owned, NULL);
    660   }
    661 
    662   // Get method that gets the data associated with Init().
    663   const char* GetData(size_t* data_size) const {
    664     return GetImpl()->GetData(data_size);
    665   }
    666 
    667   const vector<Label> GetContext(StateId s) const {
    668     return GetImpl()->GetContext(s, &inst_);
    669   }
    670 
    671   // Consumes as much as possible of context from right to left, returns the
    672   // the states corresponding to the increasingly conditioned input sequence.
    673   void GetStates(const vector<Label>& context, vector<StateId> *state) const {
    674     return GetImpl()->GetStates(context, state);
    675   }
    676 
    677   virtual size_t NumArcs(StateId s) const {
    678     return GetImpl()->NumArcs(s, &inst_);
    679   }
    680 
    681   virtual NGramFst<A>* Copy(bool safe = false) const {
    682     return new NGramFst(*this, safe);
    683   }
    684 
    685   static NGramFst<A>* Read(istream &strm, const FstReadOptions &opts) {
    686     Impl* impl = Impl::Read(strm, opts);
    687     return impl ? new NGramFst<A>(impl) : 0;
    688   }
    689 
    690   static NGramFst<A>* Read(const string &filename) {
    691     if (!filename.empty()) {
    692       ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
    693       if (!strm) {
    694         LOG(ERROR) << "NGramFst::Read: Can't open file: " << filename;
    695         return 0;
    696       }
    697       return Read(strm, FstReadOptions(filename));
    698     } else {
    699       return Read(cin, FstReadOptions("standard input"));
    700     }
    701   }
    702 
    703   virtual bool Write(ostream &strm, const FstWriteOptions &opts) const {
    704     return GetImpl()->Write(strm, opts);
    705   }
    706 
    707   virtual bool Write(const string &filename) const {
    708     return Fst<A>::WriteFile(filename);
    709   }
    710 
    711   virtual inline void InitStateIterator(StateIteratorData<A>* data) const {
    712     GetImpl()->InitStateIterator(data);
    713   }
    714 
    715   virtual inline void InitArcIterator(
    716       StateId s, ArcIteratorData<A>* data) const;
    717 
    718   virtual MatcherBase<A>* InitMatcher(MatchType match_type) const {
    719     return new NGramFstMatcher<A>(*this, match_type);
    720   }
    721 
    722   size_t StorageSize() const {
    723     return GetImpl()->StorageSize();
    724   }
    725 
    726  private:
    727   explicit NGramFst(Impl* impl) : ImplToExpandedFst<Impl>(impl) {}
    728 
    729   Impl* GetImpl() const {
    730     return
    731         ImplToExpandedFst<Impl, ExpandedFst<A> >::GetImpl();
    732   }
    733 
    734   void SetImpl(Impl* impl, bool own_impl = true) {
    735     ImplToExpandedFst<Impl, Fst<A> >::SetImpl(impl, own_impl);
    736   }
    737 
    738   mutable NGramFstInst<A> inst_;
    739 };
    740 
    741 template <class A> inline void
    742 NGramFst<A>::InitArcIterator(StateId s, ArcIteratorData<A>* data) const {
    743   GetImpl()->SetInstFuture(s, &inst_);
    744   GetImpl()->SetInstNode(&inst_);
    745   data->base = new ArcIterator<NGramFst<A> >(*this, s);
    746 }
    747 
    748 /*****************************************************************************/
    749 template <class A>
    750 class NGramFstMatcher : public MatcherBase<A> {
    751  public:
    752   typedef A Arc;
    753   typedef typename A::Label Label;
    754   typedef typename A::StateId StateId;
    755   typedef typename A::Weight Weight;
    756 
    757   NGramFstMatcher(const NGramFst<A> &fst, MatchType match_type)
    758       : fst_(fst), inst_(fst.inst_), match_type_(match_type),
    759         current_loop_(false),
    760         loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) {
    761     if (match_type_ == MATCH_OUTPUT) {
    762       swap(loop_.ilabel, loop_.olabel);
    763     }
    764   }
    765 
    766   NGramFstMatcher(const NGramFstMatcher<A> &matcher, bool safe = false)
    767       : fst_(matcher.fst_), inst_(matcher.inst_),
    768         match_type_(matcher.match_type_), current_loop_(false),
    769         loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) {
    770     if (match_type_ == MATCH_OUTPUT) {
    771       swap(loop_.ilabel, loop_.olabel);
    772     }
    773   }
    774 
    775   virtual NGramFstMatcher<A>* Copy(bool safe = false) const {
    776     return new NGramFstMatcher<A>(*this, safe);
    777   }
    778 
    779   virtual MatchType Type(bool test) const {
    780     return match_type_;
    781   }
    782 
    783   virtual const Fst<A> &GetFst() const {
    784     return fst_;
    785   }
    786 
    787   virtual uint64 Properties(uint64 props) const {
    788     return props;
    789   }
    790 
    791  private:
    792   virtual void SetState_(StateId s) {
    793     fst_.GetImpl()->SetInstFuture(s, &inst_);
    794     current_loop_ = false;
    795   }
    796 
    797   virtual bool Find_(Label label) {
    798     const Label nolabel = kNoLabel;
    799     done_ = true;
    800     if (label == 0 || label == nolabel) {
    801       if (label == 0) {
    802         current_loop_ = true;
    803         loop_.nextstate = inst_.state_;
    804       }
    805       // The unigram state has no epsilon arc.
    806       if (inst_.state_ != 0) {
    807         arc_.ilabel = arc_.olabel = 0;
    808         fst_.GetImpl()->SetInstNode(&inst_);
    809         arc_.nextstate = fst_.GetImpl()->context_index_.Rank1(
    810             fst_.GetImpl()->context_index_.Select1(
    811                 fst_.GetImpl()->context_index_.Rank0(inst_.node_) - 1));
    812         arc_.weight = fst_.GetImpl()->backoff_[inst_.state_];
    813         done_ = false;
    814       }
    815     } else {
    816       const Label *start = fst_.GetImpl()->future_words_ + inst_.offset_;
    817       const Label *end = start + inst_.num_futures_;
    818       const Label* search = lower_bound(start, end, label);
    819       if (search != end && *search == label) {
    820         size_t state = search - start;
    821         arc_.ilabel = arc_.olabel = label;
    822         arc_.weight = fst_.GetImpl()->future_probs_[inst_.offset_ + state];
    823         fst_.GetImpl()->SetInstContext(&inst_);
    824         arc_.nextstate = fst_.GetImpl()->Transition(inst_.context_, label);
    825         done_ = false;
    826       }
    827     }
    828     return !Done_();
    829   }
    830 
    831   virtual bool Done_() const {
    832     return !current_loop_ && done_;
    833   }
    834 
    835   virtual const Arc& Value_() const {
    836     return (current_loop_) ? loop_ : arc_;
    837   }
    838 
    839   virtual void Next_() {
    840     if (current_loop_) {
    841       current_loop_ = false;
    842     } else {
    843       done_ = true;
    844     }
    845   }
    846 
    847   const NGramFst<A>& fst_;
    848   NGramFstInst<A> inst_;
    849   MatchType match_type_;             // Supplied by caller
    850   bool done_;
    851   Arc arc_;
    852   bool current_loop_;                // Current arc is the implicit loop
    853   Arc loop_;
    854 };
    855 
    856 /*****************************************************************************/
    857 template<class A>
    858 class ArcIterator<NGramFst<A> > : public ArcIteratorBase<A> {
    859  public:
    860   typedef A Arc;
    861   typedef typename A::Label Label;
    862   typedef typename A::StateId StateId;
    863   typedef typename A::Weight Weight;
    864 
    865   ArcIterator(const NGramFst<A> &fst, StateId state)
    866       : lazy_(~0), impl_(fst.GetImpl()), i_(0), flags_(kArcValueFlags) {
    867     inst_ = fst.inst_;
    868     impl_->SetInstFuture(state, &inst_);
    869     impl_->SetInstNode(&inst_);
    870   }
    871 
    872   bool Done() const {
    873     return i_ >= ((inst_.node_ == 0) ? inst_.num_futures_ :
    874                   inst_.num_futures_ + 1);
    875   }
    876 
    877   const Arc &Value() const {
    878     bool eps = (inst_.node_ != 0 && i_ == 0);
    879     StateId state = (inst_.node_ == 0) ? i_ : i_ - 1;
    880     if (flags_ & lazy_ & (kArcILabelValue | kArcOLabelValue)) {
    881       arc_.ilabel =
    882           arc_.olabel = eps ? 0 : impl_->future_words_[inst_.offset_ + state];
    883       lazy_ &= ~(kArcILabelValue | kArcOLabelValue);
    884     }
    885     if (flags_ & lazy_ & kArcNextStateValue) {
    886       if (eps) {
    887         arc_.nextstate = impl_->context_index_.Rank1(
    888             impl_->context_index_.Select1(
    889                 impl_->context_index_.Rank0(inst_.node_) - 1));
    890       } else {
    891         if (lazy_ & kArcNextStateValue) {
    892           impl_->SetInstContext(&inst_);  // first time only.
    893         }
    894         arc_.nextstate =
    895             impl_->Transition(inst_.context_,
    896                               impl_->future_words_[inst_.offset_ + state]);
    897       }
    898       lazy_ &= ~kArcNextStateValue;
    899     }
    900     if (flags_ & lazy_ & kArcWeightValue) {
    901       arc_.weight = eps ?  impl_->backoff_[inst_.state_] :
    902           impl_->future_probs_[inst_.offset_ + state];
    903       lazy_ &= ~kArcWeightValue;
    904     }
    905     return arc_;
    906   }
    907 
    908   void Next() {
    909     ++i_;
    910     lazy_ = ~0;
    911   }
    912 
    913   size_t Position() const { return i_; }
    914 
    915   void Reset() {
    916     i_ = 0;
    917     lazy_ = ~0;
    918   }
    919 
    920   void Seek(size_t a) {
    921     if (i_ != a) {
    922       i_ = a;
    923       lazy_ = ~0;
    924     }
    925   }
    926 
    927   uint32 Flags() const {
    928     return flags_;
    929   }
    930 
    931   void SetFlags(uint32 f, uint32 m) {
    932     flags_ &= ~m;
    933     flags_ |= (f & kArcValueFlags);
    934   }
    935 
    936  private:
    937   virtual bool Done_() const { return Done(); }
    938   virtual const Arc& Value_() const { return Value(); }
    939   virtual void Next_() { Next(); }
    940   virtual size_t Position_() const { return Position(); }
    941   virtual void Reset_() { Reset(); }
    942   virtual void Seek_(size_t a) { Seek(a); }
    943   uint32 Flags_() const { return Flags(); }
    944   void SetFlags_(uint32 f, uint32 m) { SetFlags(f, m); }
    945 
    946   mutable Arc arc_;
    947   mutable uint32 lazy_;
    948   const NGramFstImpl<A> *impl_;
    949   mutable NGramFstInst<A> inst_;
    950 
    951   size_t i_;
    952   uint32 flags_;
    953 
    954   DISALLOW_COPY_AND_ASSIGN(ArcIterator);
    955 };
    956 
    957 /*****************************************************************************/
    958 // Specialization for NGramFst; see generic version in fst.h
    959 // for sample usage (but use the ProdLmFst type!). This version
    960 // should inline.
    961 template <class A>
    962 class StateIterator<NGramFst<A> > : public StateIteratorBase<A> {
    963   public:
    964   typedef typename A::StateId StateId;
    965 
    966   explicit StateIterator(const NGramFst<A> &fst)
    967     : s_(0), num_states_(fst.NumStates()) { }
    968 
    969   bool Done() const { return s_ >= num_states_; }
    970   StateId Value() const { return s_; }
    971   void Next() { ++s_; }
    972   void Reset() { s_ = 0; }
    973 
    974  private:
    975   virtual bool Done_() const { return Done(); }
    976   virtual StateId Value_() const { return Value(); }
    977   virtual void Next_() { Next(); }
    978   virtual void Reset_() { Reset(); }
    979 
    980   StateId s_, num_states_;
    981 
    982   DISALLOW_COPY_AND_ASSIGN(StateIterator);
    983 };
    984 }  // namespace fst
    985 #endif  // FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
    986