Home | History | Annotate | Download | only in lib
      1 // cache.h
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //      http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 //
     16 // \file
     17 // An Fst implementation that caches FST elements of a delayed
     18 // computation.
     19 
     20 #ifndef FST_LIB_CACHE_H__
     21 #define FST_LIB_CACHE_H__
     22 
     23 #include <list>
     24 
     25 #include "fst/lib/vector-fst.h"
     26 
     27 DECLARE_bool(fst_default_cache_gc);
     28 DECLARE_int64(fst_default_cache_gc_limit);
     29 
     30 namespace fst {
     31 
     32 struct CacheOptions {
     33   bool gc;          // enable GC
     34   size_t gc_limit;  // # of bytes allowed before GC
     35 
     36 
     37   CacheOptions(bool g, size_t l) : gc(g), gc_limit(l) {}
     38   CacheOptions()
     39       : gc(FLAGS_fst_default_cache_gc),
     40         gc_limit(FLAGS_fst_default_cache_gc_limit) {}
     41 };
     42 
     43 
     44 // This is a VectorFstBaseImpl container that holds a State similar to
     45 // VectorState but additionally has a flags data member (see
     46 // CacheState below). This class is used to cache FST elements with
     47 // the flags used to indicate what has been cached. Use HasStart()
     48 // HasFinal(), and HasArcs() to determine if cached and SetStart(),
     49 // SetFinal(), AddArc(), and SetArcs() to cache. Note you must set the
     50 // final weight even if the state is non-final to mark it as
     51 // cached. If the 'gc' option is 'false', cached items have the extent
     52 // of the FST - minimizing computation. If the 'gc' option is 'true',
     53 // garbage collection of states (not in use in an arc iterator) is
     54 // performed, in a rough approximation of LRU order, when 'gc_limit'
     55 // bytes is reached - controlling memory use. When 'gc_limit' is 0,
     56 // special optimizations apply - minimizing memory use.
     57 
     58 template <class S>
     59 class CacheBaseImpl : public VectorFstBaseImpl<S> {
     60  public:
     61   using FstImpl<typename S::Arc>::Type;
     62   using VectorFstBaseImpl<S>::NumStates;
     63   using VectorFstBaseImpl<S>::AddState;
     64 
     65   typedef S State;
     66   typedef typename S::Arc Arc;
     67   typedef typename Arc::Weight Weight;
     68   typedef typename Arc::StateId StateId;
     69 
     70   CacheBaseImpl()
     71       : cache_start_(false), nknown_states_(0), min_unexpanded_state_id_(0),
     72         cache_first_state_id_(kNoStateId), cache_first_state_(0),
     73         cache_gc_(FLAGS_fst_default_cache_gc),  cache_size_(0),
     74         cache_limit_(FLAGS_fst_default_cache_gc_limit > kMinCacheLimit ||
     75                      FLAGS_fst_default_cache_gc_limit == 0 ?
     76                      FLAGS_fst_default_cache_gc_limit : kMinCacheLimit) {}
     77 
     78   explicit CacheBaseImpl(const CacheOptions &opts)
     79       : cache_start_(false), nknown_states_(0),
     80         min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId),
     81         cache_first_state_(0), cache_gc_(opts.gc), cache_size_(0),
     82         cache_limit_(opts.gc_limit > kMinCacheLimit || opts.gc_limit == 0 ?
     83                      opts.gc_limit : kMinCacheLimit) {}
     84 
     85   ~CacheBaseImpl() {
     86     delete cache_first_state_;
     87   }
     88 
     89   // Gets a state from its ID; state must exist.
     90   const S *GetState(StateId s) const {
     91     if (s == cache_first_state_id_)
     92       return cache_first_state_;
     93     else
     94       return VectorFstBaseImpl<S>::GetState(s);
     95   }
     96 
     97   // Gets a state from its ID; state must exist.
     98   S *GetState(StateId s) {
     99     if (s == cache_first_state_id_)
    100       return cache_first_state_;
    101     else
    102       return VectorFstBaseImpl<S>::GetState(s);
    103   }
    104 
    105   // Gets a state from its ID; return 0 if it doesn't exist.
    106   const S *CheckState(StateId s) const {
    107     if (s == cache_first_state_id_)
    108       return cache_first_state_;
    109     else if (s < NumStates())
    110       return VectorFstBaseImpl<S>::GetState(s);
    111     else
    112       return 0;
    113   }
    114 
    115   // Gets a state from its ID; add it if necessary.
    116   S *ExtendState(StateId s) {
    117     if (s == cache_first_state_id_) {
    118       return cache_first_state_;                   // Return 1st cached state
    119     } else if (cache_limit_ == 0 && cache_first_state_id_ == kNoStateId) {
    120       cache_first_state_id_ = s;                   // Remember 1st cached state
    121       cache_first_state_ = new S;
    122       return cache_first_state_;
    123     } else if (cache_first_state_id_ != kNoStateId &&
    124                cache_first_state_->ref_count == 0) {
    125       cache_first_state_id_ = s;                   // Reuse 1st cached state
    126       cache_first_state_->Reset();
    127       return cache_first_state_;                   // Return 1st cached state
    128     } else {
    129       while (NumStates() <= s)                     // Add state to main cache
    130         AddState(0);
    131       if (!VectorFstBaseImpl<S>::GetState(s)) {
    132         SetState(s, new S);
    133         if (cache_first_state_id_ != kNoStateId) {  // Forget 1st cached state
    134           while (NumStates() <= cache_first_state_id_)
    135             AddState(0);
    136           SetState(cache_first_state_id_, cache_first_state_);
    137           if (cache_gc_) {
    138             cache_states_.push_back(cache_first_state_id_);
    139             cache_size_ += sizeof(S) +
    140                            cache_first_state_->arcs.capacity() * sizeof(Arc);
    141             cache_limit_ = kMinCacheLimit;
    142           }
    143           cache_first_state_id_ = kNoStateId;
    144           cache_first_state_ = 0;
    145         }
    146         if (cache_gc_) {
    147           cache_states_.push_back(s);
    148           cache_size_ += sizeof(S);
    149           if (cache_size_ > cache_limit_)
    150             GC(s, false);
    151         }
    152       }
    153       return VectorFstBaseImpl<S>::GetState(s);
    154     }
    155   }
    156 
    157   void SetStart(StateId s) {
    158     VectorFstBaseImpl<S>::SetStart(s);
    159     cache_start_ = true;
    160     if (s >= nknown_states_)
    161       nknown_states_ = s + 1;
    162   }
    163 
    164   void SetFinal(StateId s, Weight w) {
    165     S *state = ExtendState(s);
    166     state->final = w;
    167     state->flags |= kCacheFinal | kCacheRecent;
    168   }
    169 
    170   void AddArc(StateId s, const Arc &arc) {
    171     S *state = ExtendState(s);
    172     state->arcs.push_back(arc);
    173   }
    174 
    175   // Marks arcs of state s as cached.
    176   void SetArcs(StateId s) {
    177     S *state = ExtendState(s);
    178     vector<Arc> &arcs = state->arcs;
    179     state->niepsilons = state->noepsilons = 0;
    180     for (unsigned int a = 0; a < arcs.size(); ++a) {
    181       const Arc &arc = arcs[a];
    182       if (arc.nextstate >= nknown_states_)
    183         nknown_states_ = arc.nextstate + 1;
    184       if (arc.ilabel == 0)
    185         ++state->niepsilons;
    186       if (arc.olabel == 0)
    187         ++state->noepsilons;
    188     }
    189     ExpandedState(s);
    190     state->flags |= kCacheArcs | kCacheRecent;
    191     if (cache_gc_ && s != cache_first_state_id_) {
    192       cache_size_ += arcs.capacity() * sizeof(Arc);
    193       if (cache_size_ > cache_limit_)
    194         GC(s, false);
    195     }
    196   };
    197 
    198   void ReserveArcs(StateId s, size_t n) {
    199     S *state = ExtendState(s);
    200     state->arcs.reserve(n);
    201   }
    202 
    203   // Is the start state cached?
    204   bool HasStart() const { return cache_start_; }
    205   // Is the final weight of state s cached?
    206 
    207   bool HasFinal(StateId s) const {
    208     const S *state = CheckState(s);
    209     if (state && state->flags & kCacheFinal) {
    210       state->flags |= kCacheRecent;
    211       return true;
    212     } else {
    213       return false;
    214     }
    215   }
    216 
    217   // Are arcs of state s cached?
    218   bool HasArcs(StateId s) const {
    219     const S *state = CheckState(s);
    220     if (state && state->flags & kCacheArcs) {
    221       state->flags |= kCacheRecent;
    222       return true;
    223     } else {
    224       return false;
    225     }
    226   }
    227 
    228   Weight Final(StateId s) const {
    229     const S *state = GetState(s);
    230     return state->final;
    231   }
    232 
    233   size_t NumArcs(StateId s) const {
    234     const S *state = GetState(s);
    235     return state->arcs.size();
    236   }
    237 
    238   size_t NumInputEpsilons(StateId s) const {
    239     const S *state = GetState(s);
    240     return state->niepsilons;
    241   }
    242 
    243   size_t NumOutputEpsilons(StateId s) const {
    244     const S *state = GetState(s);
    245     return state->noepsilons;
    246   }
    247 
    248   // Provides information needed for generic arc iterator.
    249   void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const {
    250     const S *state = GetState(s);
    251     data->base = 0;
    252     data->narcs = state->arcs.size();
    253     data->arcs = data->narcs > 0 ? &(state->arcs[0]) : 0;
    254     data->ref_count = &(state->ref_count);
    255     ++(*data->ref_count);
    256   }
    257 
    258   // Number of known states.
    259   StateId NumKnownStates() const { return nknown_states_; }
    260   // Find the mininum never-expanded state Id
    261   StateId MinUnexpandedState() const {
    262     while (min_unexpanded_state_id_ < (StateId)expanded_states_.size() &&
    263           expanded_states_[min_unexpanded_state_id_])
    264       ++min_unexpanded_state_id_;
    265     return min_unexpanded_state_id_;
    266   }
    267 
    268   // Removes from cache_states_ and uncaches (not referenced-counted)
    269   // states that have not been accessed since the last GC until
    270   // cache_limit_/3 bytes are uncached.  If that fails to free enough,
    271   // recurs uncaching recently visited states as well. If still
    272   // unable to free enough memory, then widens cache_limit_.
    273   void GC(StateId current, bool free_recent) {
    274     if (!cache_gc_)
    275       return;
    276     VLOG(2) << "CacheImpl: Enter GC: object = " << Type() << "(" << this
    277             << "), free recently cached = " << free_recent
    278             << ", cache size = " << cache_size_
    279             << ", cache limit = " << cache_limit_ << "\n";
    280     typename list<StateId>::iterator siter = cache_states_.begin();
    281 
    282     size_t cache_target = (2 * cache_limit_)/3 + 1;
    283     while (siter != cache_states_.end()) {
    284       StateId s = *siter;
    285       S* state = VectorFstBaseImpl<S>::GetState(s);
    286       if (cache_size_ > cache_target && state->ref_count == 0 &&
    287           (free_recent || !(state->flags & kCacheRecent)) && s != current) {
    288         cache_size_ -= sizeof(S) + state->arcs.capacity() * sizeof(Arc);
    289         delete state;
    290         SetState(s, 0);
    291         cache_states_.erase(siter++);
    292       } else {
    293         state->flags &= ~kCacheRecent;
    294         ++siter;
    295       }
    296     }
    297     if (!free_recent && cache_size_ > cache_target) {
    298       GC(current, true);
    299     } else {
    300       while (cache_size_ > cache_target) {
    301         cache_limit_ *= 2;
    302         cache_target *= 2;
    303       }
    304     }
    305     VLOG(2) << "CacheImpl: Exit GC: object = " << Type() << "(" << this
    306             << "), free recently cached = " << free_recent
    307             << ", cache size = " << cache_size_
    308             << ", cache limit = " << cache_limit_ << "\n";
    309   }
    310 
    311  private:
    312   static const uint32 kCacheFinal =  0x0001;  // Final weight has been cached
    313   static const uint32 kCacheArcs =   0x0002;  // Arcs have been cached
    314   static const uint32 kCacheRecent = 0x0004;  // Mark as visited since GC
    315 
    316   static const size_t kMinCacheLimit;         // Minimum (non-zero) cache limit
    317 
    318   void ExpandedState(StateId s) {
    319     if (s < min_unexpanded_state_id_)
    320       return;
    321     while ((StateId)expanded_states_.size() <= s)
    322       expanded_states_.push_back(false);
    323     expanded_states_[s] = true;
    324   }
    325 
    326   bool cache_start_;                         // Is the start state cached?
    327   StateId nknown_states_;                    // # of known states
    328   vector<bool> expanded_states_;             // states that have been expanded
    329   mutable StateId min_unexpanded_state_id_;  // minimum never-expanded state Id
    330   StateId cache_first_state_id_;             // First cached state id
    331   S *cache_first_state_;                     // First cached state
    332   list<StateId> cache_states_;               // list of currently cached states
    333   bool cache_gc_;                            // enable GC
    334   size_t cache_size_;                        // # of bytes cached
    335   size_t cache_limit_;                       // # of bytes allowed before GC
    336 
    337   void InitStateIterator(StateIteratorData<Arc> *);  // disallow
    338   DISALLOW_EVIL_CONSTRUCTORS(CacheBaseImpl);
    339 };
    340 
    341 template <class S>
    342 const size_t CacheBaseImpl<S>::kMinCacheLimit = 8096;
    343 
    344 
    345 // Arcs implemented by an STL vector per state. Similar to VectorState
    346 // but adds flags and ref count to keep track of what has been cached.
    347 template <class A>
    348 struct CacheState {
    349   typedef A Arc;
    350   typedef typename A::Weight Weight;
    351   typedef typename A::StateId StateId;
    352 
    353   CacheState() :  final(Weight::Zero()), flags(0), ref_count(0) {}
    354 
    355   void Reset() {
    356     flags = 0;
    357     ref_count = 0;
    358     arcs.resize(0);
    359   }
    360 
    361   Weight final;              // Final weight
    362   vector<A> arcs;            // Arcs represenation
    363   size_t niepsilons;         // # of input epsilons
    364   size_t noepsilons;         // # of output epsilons
    365   mutable uint32 flags;
    366   mutable int ref_count;
    367 };
    368 
    369 // A CacheBaseImpl with a commonly used CacheState.
    370 template <class A>
    371 class CacheImpl : public CacheBaseImpl< CacheState<A> > {
    372  public:
    373   typedef CacheState<A> State;
    374 
    375   CacheImpl() {}
    376 
    377   explicit CacheImpl(const CacheOptions &opts)
    378       : CacheBaseImpl< CacheState<A> >(opts) {}
    379 
    380  private:
    381   DISALLOW_EVIL_CONSTRUCTORS(CacheImpl);
    382 };
    383 
    384 
    385 // Use this to make a state iterator for a CacheBaseImpl-derived Fst.
    386 // You'll need to make this class a friend of your derived Fst.
    387 // Note this iterator only returns those states reachable from
    388 // the initial state, so consider implementing a class-specific one.
    389 template <class F>
    390 class CacheStateIterator : public StateIteratorBase<typename F::Arc> {
    391  public:
    392   typedef typename F::Arc Arc;
    393   typedef typename Arc::StateId StateId;
    394 
    395   explicit CacheStateIterator(const F &fst) : fst_(fst), s_(0) {}
    396 
    397   virtual bool Done() const {
    398     if (s_ < fst_.impl_->NumKnownStates())
    399       return false;
    400     fst_.Start();  // force start state
    401     if (s_ < fst_.impl_->NumKnownStates())
    402       return false;
    403     for (int u = fst_.impl_->MinUnexpandedState();
    404          u < fst_.impl_->NumKnownStates();
    405          u = fst_.impl_->MinUnexpandedState()) {
    406       ArcIterator<F>(fst_, u);  // force state expansion
    407       if (s_ < fst_.impl_->NumKnownStates())
    408         return false;
    409     }
    410     return true;
    411   }
    412 
    413   virtual StateId Value() const { return s_; }
    414 
    415   virtual void Next() { ++s_; }
    416 
    417   virtual void Reset() { s_ = 0; }
    418 
    419  private:
    420   const F &fst_;
    421   StateId s_;
    422 };
    423 
    424 
    425 // Use this to make an arc iterator for a CacheBaseImpl-derived Fst.
    426 // You'll need to make this class a friend of your derived Fst and
    427 // define types Arc and State.
    428 template <class F>
    429 class CacheArcIterator {
    430  public:
    431   typedef typename F::Arc Arc;
    432   typedef typename F::State State;
    433   typedef typename Arc::StateId StateId;
    434 
    435   CacheArcIterator(const F &fst, StateId s) : i_(0) {
    436     state_ = fst.impl_->ExtendState(s);
    437     ++state_->ref_count;
    438   }
    439 
    440   ~CacheArcIterator() { --state_->ref_count;  }
    441 
    442   bool Done() const { return i_ >= state_->arcs.size(); }
    443 
    444   const Arc& Value() const { return state_->arcs[i_]; }
    445 
    446   void Next() { ++i_; }
    447 
    448   void Reset() { i_ = 0; }
    449 
    450   void Seek(size_t a) { i_ = a; }
    451 
    452  private:
    453   const State *state_;
    454   size_t i_;
    455 
    456   DISALLOW_EVIL_CONSTRUCTORS(CacheArcIterator);
    457 };
    458 
    459 }  // namespace fst
    460 
    461 #endif  // FST_LIB_CACHE_H__
    462