Home | History | Annotate | Download | only in fst
      1 // cache.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: riley (at) google.com (Michael Riley)
     17 //
     18 // \file
     19 // An Fst implementation that caches FST elements of a delayed
     20 // computation.
     21 
     22 #ifndef FST_LIB_CACHE_H__
     23 #define FST_LIB_CACHE_H__
     24 
     25 #include <vector>
     26 using std::vector;
     27 #include <list>
     28 
     29 #include <fst/vector-fst.h>
     30 
     31 
     32 DECLARE_bool(fst_default_cache_gc);
     33 DECLARE_int64(fst_default_cache_gc_limit);
     34 
     35 namespace fst {
     36 
     37 struct CacheOptions {
     38   bool gc;          // enable GC
     39   size_t gc_limit;  // # of bytes allowed before GC
     40 
     41   CacheOptions(bool g, size_t l) : gc(g), gc_limit(l) {}
     42   CacheOptions()
     43       : gc(FLAGS_fst_default_cache_gc),
     44         gc_limit(FLAGS_fst_default_cache_gc_limit) {}
     45 };
     46 
     47 // A CacheStateAllocator allocates and frees CacheStates
     48 // template <class S>
     49 // struct CacheStateAllocator {
     50 //   S *Allocate(StateId s);
     51 //   void Free(S *state, StateId s);
     52 // };
     53 //
     54 
     55 // A simple allocator class, can be overridden as needed,
     56 // maintains a single entry cache.
     57 template <class S>
     58 struct DefaultCacheStateAllocator {
     59   typedef typename S::Arc::StateId StateId;
     60 
     61   DefaultCacheStateAllocator() : mru_(NULL) { }
     62 
     63   ~DefaultCacheStateAllocator() {
     64     delete mru_;
     65   }
     66 
     67   S *Allocate(StateId s) {
     68     if (mru_) {
     69       S *state = mru_;
     70       mru_ = NULL;
     71       state->Reset();
     72       return state;
     73     }
     74     return new S();
     75   }
     76 
     77   void Free(S *state, StateId s) {
     78     if (mru_) {
     79       delete mru_;
     80     }
     81     mru_ = state;
     82   }
     83 
     84  private:
     85   S *mru_;
     86 };
     87 
     88 // VectorState but additionally has a flags data member (see
     89 // CacheState below). This class is used to cache FST elements with
     90 // the flags used to indicate what has been cached. Use HasStart()
     91 // HasFinal(), and HasArcs() to determine if cached and SetStart(),
     92 // SetFinal(), AddArc(), (or PushArc() and SetArcs()) to cache. Note
     93 // you must set the final weight even if the state is non-final to
     94 // mark it as cached. If the 'gc' option is 'false', cached items have
     95 // the extent of the FST - minimizing computation. If the 'gc' option
     96 // is 'true', garbage collection of states (not in use in an arc
     97 // iterator and not 'protected') is performed, in a rough
     98 // approximation of LRU order, when 'gc_limit' bytes is reached -
     99 // controlling memory use. When 'gc_limit' is 0, special optimizations
    100 // apply - minimizing memory use.
    101 
    102 template <class S, class C = DefaultCacheStateAllocator<S> >
    103 class CacheBaseImpl : public VectorFstBaseImpl<S> {
    104  public:
    105   typedef S State;
    106   typedef C Allocator;
    107   typedef typename State::Arc Arc;
    108   typedef typename Arc::Weight Weight;
    109   typedef typename Arc::StateId StateId;
    110 
    111   using FstImpl<Arc>::Type;
    112   using FstImpl<Arc>::Properties;
    113   using FstImpl<Arc>::SetProperties;
    114   using VectorFstBaseImpl<State>::NumStates;
    115   using VectorFstBaseImpl<State>::Start;
    116   using VectorFstBaseImpl<State>::AddState;
    117   using VectorFstBaseImpl<State>::SetState;
    118   using VectorFstBaseImpl<State>::ReserveStates;
    119 
    120   explicit CacheBaseImpl(C *allocator = 0)
    121       : cache_start_(false), nknown_states_(0), min_unexpanded_state_id_(0),
    122         cache_first_state_id_(kNoStateId), cache_first_state_(0),
    123         cache_gc_(FLAGS_fst_default_cache_gc),  cache_size_(0),
    124         cache_limit_(FLAGS_fst_default_cache_gc_limit > kMinCacheLimit ||
    125                      FLAGS_fst_default_cache_gc_limit == 0 ?
    126                      FLAGS_fst_default_cache_gc_limit : kMinCacheLimit),
    127         protect_(false) {
    128     allocator_ = allocator ? allocator : new C();
    129   }
    130 
    131   explicit CacheBaseImpl(const CacheOptions &opts, C *allocator = 0)
    132       : cache_start_(false), nknown_states_(0),
    133         min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId),
    134         cache_first_state_(0), cache_gc_(opts.gc), cache_size_(0),
    135         cache_limit_(opts.gc_limit > kMinCacheLimit || opts.gc_limit == 0 ?
    136                      opts.gc_limit : kMinCacheLimit),
    137         protect_(false) {
    138     allocator_ = allocator ? allocator : new C();
    139   }
    140 
    141   // Preserve gc parameters. If preserve_cache true, also preserves
    142   // cache data.
    143   CacheBaseImpl(const CacheBaseImpl<S, C> &impl, bool preserve_cache = false)
    144       : VectorFstBaseImpl<S>(), cache_start_(false), nknown_states_(0),
    145         min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId),
    146         cache_first_state_(0), cache_gc_(impl.cache_gc_), cache_size_(0),
    147         cache_limit_(impl.cache_limit_),
    148         protect_(impl.protect_) {
    149     allocator_ = new C();
    150     if (preserve_cache) {
    151       cache_start_ = impl.cache_start_;
    152       nknown_states_ = impl.nknown_states_;
    153       expanded_states_ = impl.expanded_states_;
    154       min_unexpanded_state_id_ = impl.min_unexpanded_state_id_;
    155       if (impl.cache_first_state_id_ != kNoStateId) {
    156         cache_first_state_id_ = impl.cache_first_state_id_;
    157         cache_first_state_ = allocator_->Allocate(cache_first_state_id_);
    158         *cache_first_state_ = *impl.cache_first_state_;
    159       }
    160       cache_states_ = impl.cache_states_;
    161       cache_size_ = impl.cache_size_;
    162       ReserveStates(impl.NumStates());
    163       for (StateId s = 0; s < impl.NumStates(); ++s) {
    164         const S *state =
    165             static_cast<const VectorFstBaseImpl<S> &>(impl).GetState(s);
    166         if (state) {
    167           S *copied_state = allocator_->Allocate(s);
    168           *copied_state = *state;
    169           AddState(copied_state);
    170         } else {
    171           AddState(0);
    172         }
    173       }
    174       VectorFstBaseImpl<S>::SetStart(impl.Start());
    175     }
    176   }
    177 
    178   ~CacheBaseImpl() {
    179     allocator_->Free(cache_first_state_, cache_first_state_id_);
    180     delete allocator_;
    181   }
    182 
    183   // Gets a state from its ID; state must exist.
    184   const S *GetState(StateId s) const {
    185     if (s == cache_first_state_id_)
    186       return cache_first_state_;
    187     else
    188       return VectorFstBaseImpl<S>::GetState(s);
    189   }
    190 
    191   // Gets a state from its ID; state must exist.
    192   S *GetState(StateId s) {
    193     if (s == cache_first_state_id_)
    194       return cache_first_state_;
    195     else
    196       return VectorFstBaseImpl<S>::GetState(s);
    197   }
    198 
    199   // Gets a state from its ID; return 0 if it doesn't exist.
    200   const S *CheckState(StateId s) const {
    201     if (s == cache_first_state_id_)
    202       return cache_first_state_;
    203     else if (s < NumStates())
    204       return VectorFstBaseImpl<S>::GetState(s);
    205     else
    206       return 0;
    207   }
    208 
    209   // Gets a state from its ID; add it if necessary.
    210   S *ExtendState(StateId s);
    211 
    212   void SetStart(StateId s) {
    213     VectorFstBaseImpl<S>::SetStart(s);
    214     cache_start_ = true;
    215     if (s >= nknown_states_)
    216       nknown_states_ = s + 1;
    217   }
    218 
    219   void SetFinal(StateId s, Weight w) {
    220     S *state = ExtendState(s);
    221     state->final = w;
    222     state->flags |= kCacheFinal | kCacheRecent | kCacheModified;
    223   }
    224 
    225   // AddArc adds a single arc to state s and does incremental cache
    226   // book-keeping.  For efficiency, prefer PushArc and SetArcs below
    227   // when possible.
    228   void AddArc(StateId s, const Arc &arc) {
    229     S *state = ExtendState(s);
    230     state->arcs.push_back(arc);
    231     if (arc.ilabel == 0) {
    232       ++state->niepsilons;
    233     }
    234     if (arc.olabel == 0) {
    235       ++state->noepsilons;
    236     }
    237     const Arc *parc = state->arcs.empty() ? 0 : &(state->arcs.back());
    238     SetProperties(AddArcProperties(Properties(), s, arc, parc));
    239     state->flags |= kCacheModified;
    240     if (cache_gc_ && s != cache_first_state_id_ &&
    241         !(state->flags & kCacheProtect)) {
    242       cache_size_ += sizeof(Arc);
    243       if (cache_size_ > cache_limit_)
    244         GC(s, false);
    245     }
    246   }
    247 
    248   // Adds a single arc to state s but delays cache book-keeping.
    249   // SetArcs must be called when all PushArc calls at a state are
    250   // complete.  Do not mix with calls to AddArc.
    251   void PushArc(StateId s, const Arc &arc) {
    252     S *state = ExtendState(s);
    253     state->arcs.push_back(arc);
    254   }
    255 
    256   // Marks arcs of state s as cached and does cache book-keeping after all
    257   // calls to PushArc have been completed.  Do not mix with calls to AddArc.
    258   void SetArcs(StateId s) {
    259     S *state = ExtendState(s);
    260     vector<Arc> &arcs = state->arcs;
    261     state->niepsilons = state->noepsilons = 0;
    262     for (size_t a = 0; a < arcs.size(); ++a) {
    263       const Arc &arc = arcs[a];
    264       if (arc.nextstate >= nknown_states_)
    265         nknown_states_ = arc.nextstate + 1;
    266       if (arc.ilabel == 0)
    267         ++state->niepsilons;
    268       if (arc.olabel == 0)
    269         ++state->noepsilons;
    270     }
    271     ExpandedState(s);
    272     state->flags |= kCacheArcs | kCacheRecent | kCacheModified;
    273     if (cache_gc_ && s != cache_first_state_id_ &&
    274         !(state->flags & kCacheProtect)) {
    275       cache_size_ += arcs.capacity() * sizeof(Arc);
    276       if (cache_size_ > cache_limit_)
    277         GC(s, false);
    278     }
    279   };
    280 
    281   void ReserveArcs(StateId s, size_t n) {
    282     S *state = ExtendState(s);
    283     state->arcs.reserve(n);
    284   }
    285 
    286   void DeleteArcs(StateId s, size_t n) {
    287     S *state = ExtendState(s);
    288     const vector<Arc> &arcs = state->arcs;
    289     for (size_t i = 0; i < n; ++i) {
    290       size_t j = arcs.size() - i - 1;
    291       if (arcs[j].ilabel == 0)
    292         --state->niepsilons;
    293       if (arcs[j].olabel == 0)
    294         --state->noepsilons;
    295     }
    296 
    297     state->arcs.resize(arcs.size() - n);
    298     SetProperties(DeleteArcsProperties(Properties()));
    299     state->flags |= kCacheModified;
    300     if (cache_gc_ && s != cache_first_state_id_ &&
    301         !(state->flags & kCacheProtect)) {
    302       cache_size_ -= n * sizeof(Arc);
    303     }
    304   }
    305 
    306   void DeleteArcs(StateId s) {
    307     S *state = ExtendState(s);
    308     size_t n = state->arcs.size();
    309     state->niepsilons = 0;
    310     state->noepsilons = 0;
    311     state->arcs.clear();
    312     SetProperties(DeleteArcsProperties(Properties()));
    313     state->flags |= kCacheModified;
    314     if (cache_gc_ && s != cache_first_state_id_ &&
    315         !(state->flags & kCacheProtect)) {
    316       cache_size_ -= n * sizeof(Arc);
    317     }
    318   }
    319 
    320   void DeleteStates(const vector<StateId> &dstates) {
    321     size_t old_num_states = NumStates();
    322     vector<StateId> newid(old_num_states, 0);
    323     for (size_t i = 0; i < dstates.size(); ++i)
    324       newid[dstates[i]] = kNoStateId;
    325     StateId nstates = 0;
    326     for (StateId s = 0; s < old_num_states; ++s) {
    327       if (newid[s] != kNoStateId) {
    328         newid[s] = nstates;
    329         ++nstates;
    330       }
    331     }
    332     // just for states_.resize(), does unnecessary walk.
    333     VectorFstBaseImpl<S>::DeleteStates(dstates);
    334     SetProperties(DeleteStatesProperties(Properties()));
    335     // Update list of cached states.
    336     typename list<StateId>::iterator siter = cache_states_.begin();
    337     while (siter != cache_states_.end()) {
    338       if (newid[*siter] != kNoStateId) {
    339         *siter = newid[*siter];
    340         ++siter;
    341       } else {
    342         cache_states_.erase(siter++);
    343       }
    344     }
    345   }
    346 
    347   void DeleteStates() {
    348     cache_states_.clear();
    349     allocator_->Free(cache_first_state_, cache_first_state_id_);
    350     for (int s = 0; s < NumStates(); ++s) {
    351       allocator_->Free(VectorFstBaseImpl<S>::GetState(s), s);
    352       SetState(s, 0);
    353     }
    354     nknown_states_ = 0;
    355     min_unexpanded_state_id_ = 0;
    356     cache_first_state_id_ = kNoStateId;
    357     cache_first_state_ = 0;
    358     cache_size_ = 0;
    359     cache_start_ = false;
    360     VectorFstBaseImpl<State>::DeleteStates();
    361     SetProperties(DeleteAllStatesProperties(Properties(),
    362                                             kExpanded | kMutable));
    363   }
    364 
    365   // Is the start state cached?
    366   bool HasStart() const {
    367     if (!cache_start_ && Properties(kError))
    368       cache_start_ = true;
    369     return cache_start_;
    370   }
    371 
    372   // Is the final weight of state s cached?
    373   bool HasFinal(StateId s) const {
    374     const S *state = CheckState(s);
    375     if (state && state->flags & kCacheFinal) {
    376       state->flags |= kCacheRecent;
    377       return true;
    378     } else {
    379       return false;
    380     }
    381   }
    382 
    383   // Are arcs of state s cached?
    384   bool HasArcs(StateId s) const {
    385     const S *state = CheckState(s);
    386     if (state && state->flags & kCacheArcs) {
    387       state->flags |= kCacheRecent;
    388       return true;
    389     } else {
    390       return false;
    391     }
    392   }
    393 
    394   Weight Final(StateId s) const {
    395     const S *state = GetState(s);
    396     return state->final;
    397   }
    398 
    399   size_t NumArcs(StateId s) const {
    400     const S *state = GetState(s);
    401     return state->arcs.size();
    402   }
    403 
    404   size_t NumInputEpsilons(StateId s) const {
    405     const S *state = GetState(s);
    406     return state->niepsilons;
    407   }
    408 
    409   size_t NumOutputEpsilons(StateId s) const {
    410     const S *state = GetState(s);
    411     return state->noepsilons;
    412   }
    413 
    414   // Provides information needed for generic arc iterator.
    415   void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const {
    416     const S *state = GetState(s);
    417     data->base = 0;
    418     data->narcs = state->arcs.size();
    419     data->arcs = data->narcs > 0 ? &(state->arcs[0]) : 0;
    420     data->ref_count = &(state->ref_count);
    421     ++(*data->ref_count);
    422   }
    423 
    424   // Number of known states.
    425   StateId NumKnownStates() const { return nknown_states_; }
    426 
    427   // Update number of known states taking in account the existence of state s.
    428   void UpdateNumKnownStates(StateId s) {
    429     if (s >= nknown_states_)
    430       nknown_states_ = s + 1;
    431   }
    432 
    433   // Find the mininum never-expanded state Id
    434   StateId MinUnexpandedState() const {
    435     while (min_unexpanded_state_id_ < expanded_states_.size() &&
    436           expanded_states_[min_unexpanded_state_id_])
    437       ++min_unexpanded_state_id_;
    438     return min_unexpanded_state_id_;
    439   }
    440 
    441   // Removes from cache_states_ and uncaches (not referenced-counted
    442   // or protected) states that have not been accessed since the last
    443   // GC until at most cache_fraction * cache_limit_ bytes are cached.
    444   // If that fails to free enough, recurs uncaching recently visited
    445   // states as well. If still unable to free enough memory, then
    446   // widens cache_limit_ to fulfill condition.
    447   void GC(StateId current, bool free_recent,  float cache_fraction = 0.666);
    448 
    449   // Setc/clears GC protection: if true, new states are protected
    450   // from garbage collection.
    451   void GCProtect(bool on) { protect_ = on; }
    452 
    453   void ExpandedState(StateId s) {
    454     if (s < min_unexpanded_state_id_)
    455       return;
    456     while (expanded_states_.size() <= s)
    457       expanded_states_.push_back(false);
    458     expanded_states_[s] = true;
    459   }
    460 
    461   C *GetAllocator() const {
    462     return allocator_;
    463   }
    464 
    465   // Caching on/off switch, limit and size accessors.
    466   bool GetCacheGc() const { return cache_gc_; }
    467   size_t GetCacheLimit() const { return cache_limit_; }
    468   size_t GetCacheSize() const { return cache_size_; }
    469 
    470  private:
    471   static const size_t kMinCacheLimit = 8096;   // Minimum (non-zero) cache limit
    472 
    473   static const uint32 kCacheFinal =    0x0001;  // Final weight has been cached
    474   static const uint32 kCacheArcs =     0x0002;  // Arcs have been cached
    475   static const uint32 kCacheRecent =   0x0004;  // Mark as visited since GC
    476   static const uint32 kCacheProtect =  0x0008;  // Mark state as GC protected
    477 
    478  public:
    479   static const uint32 kCacheModified = 0x0010;  // Mark state as modified
    480   static const uint32 kCacheFlags = kCacheFinal | kCacheArcs | kCacheRecent
    481       | kCacheProtect | kCacheModified;
    482 
    483  private:
    484   C *allocator_;                             // used to allocate new states
    485   mutable bool cache_start_;                 // Is the start state cached?
    486   StateId nknown_states_;                    // # of known states
    487   vector<bool> expanded_states_;             // states that have been expanded
    488   mutable StateId min_unexpanded_state_id_;  // minimum never-expanded state Id
    489   StateId cache_first_state_id_;             // First cached state id
    490   S *cache_first_state_;                     // First cached state
    491   list<StateId> cache_states_;               // list of currently cached states
    492   bool cache_gc_;                            // enable GC
    493   size_t cache_size_;                        // # of bytes cached
    494   size_t cache_limit_;                       // # of bytes allowed before GC
    495   bool protect_;                             // Protect new states from GC
    496 
    497   void operator=(const CacheBaseImpl<S, C> &impl);    // disallow
    498 };
    499 
    500 // Gets a state from its ID; add it if necessary.
    501 template <class S, class C>
    502 S *CacheBaseImpl<S, C>::ExtendState(typename S::Arc::StateId s) {
    503   // If 'protect_' true and a new state, protects from garbage collection.
    504   if (s == cache_first_state_id_) {
    505     return cache_first_state_;                   // Return 1st cached state
    506   } else if (cache_limit_ == 0 && cache_first_state_id_ == kNoStateId) {
    507     cache_first_state_id_ = s;                   // Remember 1st cached state
    508     cache_first_state_ = allocator_->Allocate(s);
    509     if (protect_) cache_first_state_->flags |= kCacheProtect;
    510     return cache_first_state_;
    511   } else if (cache_first_state_id_ != kNoStateId &&
    512              cache_first_state_->ref_count == 0 &&
    513              !(cache_first_state_->flags & kCacheProtect)) {
    514     // With Default allocator, the Free and Allocate will reuse the same S*.
    515     allocator_->Free(cache_first_state_, cache_first_state_id_);
    516     cache_first_state_id_ = s;
    517     cache_first_state_ = allocator_->Allocate(s);
    518     if (protect_) cache_first_state_->flags |= kCacheProtect;
    519     return cache_first_state_;                   // Return 1st cached state
    520   } else {
    521     while (NumStates() <= s)                     // Add state to main cache
    522       AddState(0);
    523     S *state = VectorFstBaseImpl<S>::GetState(s);
    524     if (!state) {
    525       state = allocator_->Allocate(s);
    526       if (protect_) state->flags |= kCacheProtect;
    527       SetState(s, state);
    528       if (cache_first_state_id_ != kNoStateId) {  // Forget 1st cached state
    529         while (NumStates() <= cache_first_state_id_)
    530           AddState(0);
    531         SetState(cache_first_state_id_, cache_first_state_);
    532         if (cache_gc_ && !(cache_first_state_->flags & kCacheProtect)) {
    533           cache_states_.push_back(cache_first_state_id_);
    534           cache_size_ += sizeof(S) +
    535                          cache_first_state_->arcs.capacity() * sizeof(Arc);
    536         }
    537         cache_limit_ = kMinCacheLimit;
    538         cache_first_state_id_ = kNoStateId;
    539         cache_first_state_ = 0;
    540       }
    541       if (cache_gc_ && !protect_) {
    542         cache_states_.push_back(s);
    543         cache_size_ += sizeof(S);
    544         if (cache_size_ > cache_limit_)
    545           GC(s, false);
    546       }
    547     }
    548     return state;
    549   }
    550 }
    551 
    552 // Removes from cache_states_ and uncaches (not referenced-counted or
    553 // protected) states that have not been accessed since the last GC
    554 // until at most cache_fraction * cache_limit_ bytes are cached.  If
    555 // that fails to free enough, recurs uncaching recently visited states
    556 // as well. If still unable to free enough memory, then widens cache_limit_
    557 // to fulfill condition.
    558 template <class S, class C>
    559 void CacheBaseImpl<S, C>::GC(typename S::Arc::StateId current,
    560                              bool free_recent, float cache_fraction) {
    561   if (!cache_gc_)
    562     return;
    563   VLOG(2) << "CacheImpl: Enter GC: object = " << Type() << "(" << this
    564           << "), free recently cached = " << free_recent
    565           << ", cache size = " << cache_size_
    566           << ", cache frac = " << cache_fraction
    567           << ", cache limit = " << cache_limit_ << "\n";
    568   typename list<StateId>::iterator siter = cache_states_.begin();
    569 
    570   size_t cache_target = cache_fraction * cache_limit_;
    571   while (siter != cache_states_.end()) {
    572     StateId s = *siter;
    573     S* state = VectorFstBaseImpl<S>::GetState(s);
    574     if (cache_size_ > cache_target && state->ref_count == 0 &&
    575         (free_recent || !(state->flags & kCacheRecent)) && s != current) {
    576       cache_size_ -= sizeof(S) + state->arcs.capacity() * sizeof(Arc);
    577       allocator_->Free(state, s);
    578       SetState(s, 0);
    579       cache_states_.erase(siter++);
    580     } else {
    581       state->flags &= ~kCacheRecent;
    582       ++siter;
    583     }
    584   }
    585   if (!free_recent && cache_size_ > cache_target) {   // recurses on recent
    586     GC(current, true);
    587   } else if (cache_target > 0) {                      // widens cache limit
    588     while (cache_size_ > cache_target) {
    589       cache_limit_ *= 2;
    590       cache_target *= 2;
    591     }
    592   } else if (cache_size_ > 0) {
    593     FSTERROR() << "CacheImpl:GC: Unable to free all cached states";
    594   }
    595   VLOG(2) << "CacheImpl: Exit GC: object = " << Type() << "(" << this
    596           << "), free recently cached = " << free_recent
    597           << ", cache size = " << cache_size_
    598           << ", cache frac = " << cache_fraction
    599           << ", cache limit = " << cache_limit_ << "\n";
    600 }
    601 
    602 template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheFinal;
    603 template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheArcs;
    604 template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheRecent;
    605 template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheModified;
    606 template <class S, class C> const size_t CacheBaseImpl<S, C>::kMinCacheLimit;
    607 
    608 // Arcs implemented by an STL vector per state. Similar to VectorState
    609 // but adds flags and ref count to keep track of what has been cached.
    610 template <class A>
    611 struct CacheState {
    612   typedef A Arc;
    613   typedef typename A::Weight Weight;
    614   typedef typename A::StateId StateId;
    615 
    616   CacheState() :  final(Weight::Zero()), flags(0), ref_count(0) {}
    617 
    618   void Reset() {
    619     flags = 0;
    620     ref_count = 0;
    621     arcs.resize(0);
    622   }
    623 
    624   Weight final;              // Final weight
    625   vector<A> arcs;            // Arcs represenation
    626   size_t niepsilons;         // # of input epsilons
    627   size_t noepsilons;         // # of output epsilons
    628   mutable uint32 flags;
    629   mutable int ref_count;
    630 };
    631 
    632 // A CacheBaseImpl with a commonly used CacheState.
    633 template <class A>
    634 class CacheImpl : public CacheBaseImpl< CacheState<A> > {
    635  public:
    636   typedef CacheState<A> State;
    637 
    638   CacheImpl() {}
    639 
    640   explicit CacheImpl(const CacheOptions &opts)
    641       : CacheBaseImpl< CacheState<A> >(opts) {}
    642 
    643   CacheImpl(const CacheImpl<A> &impl, bool preserve_cache = false)
    644       : CacheBaseImpl<State>(impl, preserve_cache) {}
    645 
    646  private:
    647   void operator=(const CacheImpl<State> &impl);    // disallow
    648 };
    649 
    650 
    651 // Use this to make a state iterator for a CacheBaseImpl-derived Fst,
    652 // which must have type 'State' defined.  Note this iterator only
    653 // returns those states reachable from the initial state, so consider
    654 // implementing a class-specific one.
    655 template <class F>
    656 class CacheStateIterator : public StateIteratorBase<typename F::Arc> {
    657  public:
    658   typedef typename F::Arc Arc;
    659   typedef typename Arc::StateId StateId;
    660   typedef typename F::State State;
    661   typedef CacheBaseImpl<State> Impl;
    662 
    663   CacheStateIterator(const F &fst, Impl *impl)
    664       : fst_(fst), impl_(impl), s_(0) {
    665         fst_.Start();  // force start state
    666       }
    667 
    668   bool Done() const {
    669     if (s_ < impl_->NumKnownStates())
    670       return false;
    671     if (s_ < impl_->NumKnownStates())
    672       return false;
    673     for (StateId u = impl_->MinUnexpandedState();
    674          u < impl_->NumKnownStates();
    675          u = impl_->MinUnexpandedState()) {
    676       // force state expansion
    677       ArcIterator<F> aiter(fst_, u);
    678       aiter.SetFlags(kArcValueFlags, kArcValueFlags | kArcNoCache);
    679       for (; !aiter.Done(); aiter.Next())
    680         impl_->UpdateNumKnownStates(aiter.Value().nextstate);
    681       impl_->ExpandedState(u);
    682       if (s_ < impl_->NumKnownStates())
    683         return false;
    684     }
    685     return true;
    686   }
    687 
    688   StateId Value() const { return s_; }
    689 
    690   void Next() { ++s_; }
    691 
    692   void Reset() { s_ = 0; }
    693 
    694  private:
    695   // This allows base class virtual access to non-virtual derived-
    696   // class members of the same name. It makes the derived class more
    697   // efficient to use but unsafe to further derive.
    698   virtual bool Done_() const { return Done(); }
    699   virtual StateId Value_() const { return Value(); }
    700   virtual void Next_() { Next(); }
    701   virtual void Reset_() { Reset(); }
    702 
    703   const F &fst_;
    704   Impl *impl_;
    705   StateId s_;
    706 };
    707 
    708 
    709 // Use this to make an arc iterator for a CacheBaseImpl-derived Fst,
    710 // which must have types 'Arc' and 'State' defined.
    711 template <class F,
    712           class C = DefaultCacheStateAllocator<CacheState<typename F::Arc> > >
    713 class CacheArcIterator {
    714  public:
    715   typedef typename F::Arc Arc;
    716   typedef typename F::State State;
    717   typedef typename Arc::StateId StateId;
    718   typedef CacheBaseImpl<State, C> Impl;
    719 
    720   CacheArcIterator(Impl *impl, StateId s) : i_(0) {
    721     state_ = impl->ExtendState(s);
    722     ++state_->ref_count;
    723   }
    724 
    725   ~CacheArcIterator() { --state_->ref_count;  }
    726 
    727   bool Done() const { return i_ >= state_->arcs.size(); }
    728 
    729   const Arc& Value() const { return state_->arcs[i_]; }
    730 
    731   void Next() { ++i_; }
    732 
    733   size_t Position() const { return i_; }
    734 
    735   void Reset() { i_ = 0; }
    736 
    737   void Seek(size_t a) { i_ = a; }
    738 
    739   uint32 Flags() const {
    740     return kArcValueFlags;
    741   }
    742 
    743   void SetFlags(uint32 flags, uint32 mask) {}
    744 
    745  private:
    746   const State *state_;
    747   size_t i_;
    748 
    749   DISALLOW_COPY_AND_ASSIGN(CacheArcIterator);
    750 };
    751 
    752 // Use this to make a mutable arc iterator for a CacheBaseImpl-derived Fst,
    753 // which must have types 'Arc' and 'State' defined.
    754 template <class F,
    755           class C = DefaultCacheStateAllocator<CacheState<typename F::Arc> > >
    756 class CacheMutableArcIterator
    757     : public MutableArcIteratorBase<typename F::Arc> {
    758  public:
    759   typedef typename F::State State;
    760   typedef typename F::Arc Arc;
    761   typedef typename Arc::StateId StateId;
    762   typedef typename Arc::Weight Weight;
    763   typedef CacheBaseImpl<State, C> Impl;
    764 
    765   // You will need to call MutateCheck() in the constructor.
    766   CacheMutableArcIterator(Impl *impl, StateId s) : i_(0), s_(s), impl_(impl) {
    767     state_ = impl_->ExtendState(s_);
    768     ++state_->ref_count;
    769   };
    770 
    771   ~CacheMutableArcIterator() {
    772     --state_->ref_count;
    773   }
    774 
    775   bool Done() const { return i_ >= state_->arcs.size(); }
    776 
    777   const Arc& Value() const { return state_->arcs[i_]; }
    778 
    779   void Next() { ++i_; }
    780 
    781   size_t Position() const { return i_; }
    782 
    783   void Reset() { i_ = 0; }
    784 
    785   void Seek(size_t a) { i_ = a; }
    786 
    787   void SetValue(const Arc& arc) {
    788     state_->flags |= CacheBaseImpl<State, C>::kCacheModified;
    789     uint64 properties = impl_->Properties();
    790     Arc& oarc = state_->arcs[i_];
    791     if (oarc.ilabel != oarc.olabel)
    792       properties &= ~kNotAcceptor;
    793     if (oarc.ilabel == 0) {
    794       --state_->niepsilons;
    795       properties &= ~kIEpsilons;
    796       if (oarc.olabel == 0)
    797         properties &= ~kEpsilons;
    798     }
    799     if (oarc.olabel == 0) {
    800       --state_->noepsilons;
    801       properties &= ~kOEpsilons;
    802     }
    803     if (oarc.weight != Weight::Zero() && oarc.weight != Weight::One())
    804       properties &= ~kWeighted;
    805     oarc = arc;
    806     if (arc.ilabel != arc.olabel) {
    807       properties |= kNotAcceptor;
    808       properties &= ~kAcceptor;
    809     }
    810     if (arc.ilabel == 0) {
    811       ++state_->niepsilons;
    812       properties |= kIEpsilons;
    813       properties &= ~kNoIEpsilons;
    814       if (arc.olabel == 0) {
    815         properties |= kEpsilons;
    816         properties &= ~kNoEpsilons;
    817       }
    818     }
    819     if (arc.olabel == 0) {
    820       ++state_->noepsilons;
    821       properties |= kOEpsilons;
    822       properties &= ~kNoOEpsilons;
    823     }
    824     if (arc.weight != Weight::Zero() && arc.weight != Weight::One()) {
    825       properties |= kWeighted;
    826       properties &= ~kUnweighted;
    827     }
    828     properties &= kSetArcProperties | kAcceptor | kNotAcceptor |
    829         kEpsilons | kNoEpsilons | kIEpsilons | kNoIEpsilons |
    830         kOEpsilons | kNoOEpsilons | kWeighted | kUnweighted;
    831     impl_->SetProperties(properties);
    832   }
    833 
    834   uint32 Flags() const {
    835     return kArcValueFlags;
    836   }
    837 
    838   void SetFlags(uint32 f, uint32 m) {}
    839 
    840  private:
    841   virtual bool Done_() const { return Done(); }
    842   virtual const Arc& Value_() const { return Value(); }
    843   virtual void Next_() { Next(); }
    844   virtual size_t Position_() const { return Position(); }
    845   virtual void Reset_() { Reset(); }
    846   virtual void Seek_(size_t a) { Seek(a); }
    847   virtual void SetValue_(const Arc &a) { SetValue(a); }
    848   uint32 Flags_() const { return Flags(); }
    849   void SetFlags_(uint32 f, uint32 m) { SetFlags(f, m); }
    850 
    851   size_t i_;
    852   StateId s_;
    853   Impl *impl_;
    854   State *state_;
    855 
    856   DISALLOW_COPY_AND_ASSIGN(CacheMutableArcIterator);
    857 };
    858 
    859 }  // namespace fst
    860 
    861 #endif  // FST_LIB_CACHE_H__
    862