Home | History | Annotate | Download | only in fst
      1 // accumulator.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: riley (at) google.com (Michael Riley)
     17 //
     18 // \file
     19 // Classes to accumulate arc weights. Useful for weight lookahead.
     20 
     21 #ifndef FST_LIB_ACCUMULATOR_H__
     22 #define FST_LIB_ACCUMULATOR_H__
     23 
     24 #include <algorithm>
     25 #include <functional>
     26 #include <tr1/unordered_map>
     27 using std::tr1::unordered_map;
     28 using std::tr1::unordered_multimap;
     29 #include <vector>
     30 using std::vector;
     31 
     32 #include <fst/arcfilter.h>
     33 #include <fst/arcsort.h>
     34 #include <fst/dfs-visit.h>
     35 #include <fst/expanded-fst.h>
     36 #include <fst/replace.h>
     37 
     38 namespace fst {
     39 
     40 // This class accumulates arc weights using the semiring Plus().
     41 template <class A>
     42 class DefaultAccumulator {
     43  public:
     44   typedef A Arc;
     45   typedef typename A::StateId StateId;
     46   typedef typename A::Weight Weight;
     47 
     48   DefaultAccumulator() {}
     49 
     50   DefaultAccumulator(const DefaultAccumulator<A> &acc) {}
     51 
     52   void Init(const Fst<A>& fst, bool copy = false) {}
     53 
     54   void SetState(StateId) {}
     55 
     56   Weight Sum(Weight w, Weight v) {
     57     return Plus(w, v);
     58   }
     59 
     60   template <class ArcIterator>
     61   Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin,
     62              ssize_t end) {
     63     Weight sum = w;
     64     aiter->Seek(begin);
     65     for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos)
     66       sum = Plus(sum, aiter->Value().weight);
     67     return sum;
     68   }
     69 
     70   bool Error() const { return false; }
     71 
     72  private:
     73   void operator=(const DefaultAccumulator<A> &);   // Disallow
     74 };
     75 
     76 
     77 // This class accumulates arc weights using the log semiring Plus()
     78 // assuming an arc weight has a WeightConvert specialization to
     79 // and from log64 weights.
     80 template <class A>
     81 class LogAccumulator {
     82  public:
     83   typedef A Arc;
     84   typedef typename A::StateId StateId;
     85   typedef typename A::Weight Weight;
     86 
     87   LogAccumulator() {}
     88 
     89   LogAccumulator(const LogAccumulator<A> &acc) {}
     90 
     91   void Init(const Fst<A>& fst, bool copy = false) {}
     92 
     93   void SetState(StateId) {}
     94 
     95   Weight Sum(Weight w, Weight v) {
     96     return LogPlus(w, v);
     97   }
     98 
     99   template <class ArcIterator>
    100   Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin,
    101              ssize_t end) {
    102     Weight sum = w;
    103     aiter->Seek(begin);
    104     for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos)
    105       sum = LogPlus(sum, aiter->Value().weight);
    106     return sum;
    107   }
    108 
    109   bool Error() const { return false; }
    110 
    111  private:
    112   double LogPosExp(double x) { return log(1.0F + exp(-x)); }
    113 
    114   Weight LogPlus(Weight w, Weight v) {
    115     double f1 = to_log_weight_(w).Value();
    116     double f2 = to_log_weight_(v).Value();
    117     if (f1 > f2)
    118       return to_weight_(f2 - LogPosExp(f1 - f2));
    119     else
    120       return to_weight_(f1 - LogPosExp(f2 - f1));
    121   }
    122 
    123   WeightConvert<Weight, Log64Weight> to_log_weight_;
    124   WeightConvert<Log64Weight, Weight> to_weight_;
    125 
    126   void operator=(const LogAccumulator<A> &);   // Disallow
    127 };
    128 
    129 
    130 // Stores shareable data for fast log accumulator copies.
    131 class FastLogAccumulatorData {
    132  public:
    133   FastLogAccumulatorData() {}
    134 
    135   vector<double> *Weights() { return &weights_; }
    136   vector<ssize_t> *WeightPositions() { return &weight_positions_; }
    137   double *WeightEnd() { return &(weights_[weights_.size() - 1]); };
    138   int RefCount() const { return ref_count_.count(); }
    139   int IncrRefCount() { return ref_count_.Incr(); }
    140   int DecrRefCount() { return ref_count_.Decr(); }
    141 
    142  private:
    143   // Cummulative weight per state for all states s.t. # of arcs >
    144   // arc_limit_ with arcs in order. Special first element per state
    145   // being Log64Weight::Zero();
    146   vector<double> weights_;
    147   // Maps from state to corresponding beginning weight position in
    148   // weights_. Position -1 means no pre-computed weights for that
    149   // state.
    150   vector<ssize_t> weight_positions_;
    151   RefCounter ref_count_;                  // Reference count.
    152 
    153   DISALLOW_COPY_AND_ASSIGN(FastLogAccumulatorData);
    154 };
    155 
    156 
    157 // This class accumulates arc weights using the log semiring Plus()
    158 // assuming an arc weight has a WeightConvert specialization to and
    159 // from log64 weights. The member function Init(fst) has to be called
    160 // to setup pre-computed weight information.
    161 template <class A>
    162 class FastLogAccumulator {
    163  public:
    164   typedef A Arc;
    165   typedef typename A::StateId StateId;
    166   typedef typename A::Weight Weight;
    167 
    168   explicit FastLogAccumulator(ssize_t arc_limit = 20, ssize_t arc_period = 10)
    169       : arc_limit_(arc_limit),
    170         arc_period_(arc_period),
    171         data_(new FastLogAccumulatorData()),
    172         error_(false) {}
    173 
    174   FastLogAccumulator(const FastLogAccumulator<A> &acc)
    175       : arc_limit_(acc.arc_limit_),
    176         arc_period_(acc.arc_period_),
    177         data_(acc.data_),
    178         error_(acc.error_) {
    179     data_->IncrRefCount();
    180   }
    181 
    182   ~FastLogAccumulator() {
    183     if (!data_->DecrRefCount())
    184       delete data_;
    185   }
    186 
    187   void SetState(StateId s) {
    188     vector<double> &weights = *data_->Weights();
    189     vector<ssize_t> &weight_positions = *data_->WeightPositions();
    190 
    191     if (weight_positions.size() <= s) {
    192       FSTERROR() << "FastLogAccumulator::SetState: invalid state id.";
    193       error_ = true;
    194       return;
    195     }
    196 
    197     ssize_t pos = weight_positions[s];
    198     if (pos >= 0)
    199       state_weights_ = &(weights[pos]);
    200     else
    201       state_weights_ = 0;
    202   }
    203 
    204   Weight Sum(Weight w, Weight v) {
    205     return LogPlus(w, v);
    206   }
    207 
    208   template <class ArcIterator>
    209   Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin,
    210              ssize_t end) {
    211     if (error_) return Weight::NoWeight();
    212     Weight sum = w;
    213     // Finds begin and end of pre-stored weights
    214     ssize_t index_begin = -1, index_end = -1;
    215     ssize_t stored_begin = end, stored_end = end;
    216     if (state_weights_ != 0) {
    217       index_begin = begin > 0 ? (begin - 1)/ arc_period_ + 1 : 0;
    218       index_end = end / arc_period_;
    219       stored_begin = index_begin * arc_period_;
    220       stored_end = index_end * arc_period_;
    221     }
    222     // Computes sum before pre-stored weights
    223     if (begin < stored_begin) {
    224       ssize_t pos_end = min(stored_begin, end);
    225       aiter->Seek(begin);
    226       for (ssize_t pos = begin; pos < pos_end; aiter->Next(), ++pos)
    227         sum = LogPlus(sum, aiter->Value().weight);
    228     }
    229     // Computes sum between pre-stored weights
    230     if (stored_begin < stored_end) {
    231       sum = LogPlus(sum, LogMinus(state_weights_[index_end],
    232                                   state_weights_[index_begin]));
    233     }
    234     // Computes sum after pre-stored weights
    235     if (stored_end < end) {
    236       ssize_t pos_start = max(stored_begin, stored_end);
    237       aiter->Seek(pos_start);
    238       for (ssize_t pos = pos_start; pos < end; aiter->Next(), ++pos)
    239         sum = LogPlus(sum, aiter->Value().weight);
    240     }
    241     return sum;
    242   }
    243 
    244   template <class F>
    245   void Init(const F &fst, bool copy = false) {
    246     if (copy)
    247       return;
    248     vector<double> &weights = *data_->Weights();
    249     vector<ssize_t> &weight_positions = *data_->WeightPositions();
    250     if (!weights.empty() || arc_limit_ < arc_period_) {
    251       FSTERROR() << "FastLogAccumulator: initialization error.";
    252       error_ = true;
    253       return;
    254     }
    255     weight_positions.reserve(CountStates(fst));
    256 
    257     ssize_t weight_position = 0;
    258     for(StateIterator<F> siter(fst); !siter.Done(); siter.Next()) {
    259       StateId s = siter.Value();
    260       if (fst.NumArcs(s) >= arc_limit_) {
    261         double sum = FloatLimits<double>::PosInfinity();
    262         weight_positions.push_back(weight_position);
    263         weights.push_back(sum);
    264         ++weight_position;
    265         ssize_t narcs = 0;
    266         for(ArcIterator<F> aiter(fst, s); !aiter.Done(); aiter.Next()) {
    267           const A &arc = aiter.Value();
    268           sum = LogPlus(sum, arc.weight);
    269           // Stores cumulative weight distribution per arc_period_.
    270           if (++narcs % arc_period_ == 0) {
    271             weights.push_back(sum);
    272             ++weight_position;
    273           }
    274         }
    275       } else {
    276         weight_positions.push_back(-1);
    277       }
    278     }
    279   }
    280 
    281   bool Error() const { return error_; }
    282 
    283  private:
    284   double LogPosExp(double x) {
    285     return x == FloatLimits<double>::PosInfinity() ?
    286         0.0 : log(1.0F + exp(-x));
    287   }
    288 
    289   double LogMinusExp(double x) {
    290     return x == FloatLimits<double>::PosInfinity() ?
    291         0.0 : log(1.0F - exp(-x));
    292   }
    293 
    294   Weight LogPlus(Weight w, Weight v) {
    295     double f1 = to_log_weight_(w).Value();
    296     double f2 = to_log_weight_(v).Value();
    297     if (f1 > f2)
    298       return to_weight_(f2 - LogPosExp(f1 - f2));
    299     else
    300       return to_weight_(f1 - LogPosExp(f2 - f1));
    301   }
    302 
    303   double LogPlus(double f1, Weight v) {
    304     double f2 = to_log_weight_(v).Value();
    305     if (f1 == FloatLimits<double>::PosInfinity())
    306       return f2;
    307     else if (f1 > f2)
    308       return f2 - LogPosExp(f1 - f2);
    309     else
    310       return f1 - LogPosExp(f2 - f1);
    311   }
    312 
    313   Weight LogMinus(double f1, double f2) {
    314     if (f1 >= f2) {
    315       FSTERROR() << "FastLogAcumulator::LogMinus: f1 >= f2 with f1 = " << f1
    316                  << " and f2 = " << f2;
    317       error_ = true;
    318       return Weight::NoWeight();
    319     }
    320     if (f2 == FloatLimits<double>::PosInfinity())
    321       return to_weight_(f1);
    322     else
    323       return to_weight_(f1 - LogMinusExp(f2 - f1));
    324   }
    325 
    326   WeightConvert<Weight, Log64Weight> to_log_weight_;
    327   WeightConvert<Log64Weight, Weight> to_weight_;
    328 
    329   ssize_t arc_limit_;     // Minimum # of arcs to pre-compute state
    330   ssize_t arc_period_;    // Save cumulative weights per 'arc_period_'.
    331   bool init_;             // Cumulative weights initialized?
    332   FastLogAccumulatorData *data_;
    333   double *state_weights_;
    334   bool error_;
    335 
    336   void operator=(const FastLogAccumulator<A> &);   // Disallow
    337 };
    338 
    339 
    340 // Stores shareable data for cache log accumulator copies.
    341 // All copies share the same cache.
    342 template <class A>
    343 class CacheLogAccumulatorData {
    344  public:
    345   typedef A Arc;
    346   typedef typename A::StateId StateId;
    347   typedef typename A::Weight Weight;
    348 
    349   CacheLogAccumulatorData(bool gc, size_t gc_limit)
    350       : cache_gc_(gc), cache_limit_(gc_limit), cache_size_(0) {}
    351 
    352   ~CacheLogAccumulatorData() {
    353     for(typename unordered_map<StateId, CacheState>::iterator it = cache_.begin();
    354         it != cache_.end();
    355         ++it)
    356       delete it->second.weights;
    357   }
    358 
    359   bool CacheDisabled() const { return cache_gc_ && cache_limit_ == 0; }
    360 
    361   vector<double> *GetWeights(StateId s) {
    362     typename unordered_map<StateId, CacheState>::iterator it = cache_.find(s);
    363     if (it != cache_.end()) {
    364       it->second.recent = true;
    365       return it->second.weights;
    366     } else {
    367       return 0;
    368     }
    369   }
    370 
    371   void AddWeights(StateId s, vector<double> *weights) {
    372     if (cache_gc_ && cache_size_ >= cache_limit_)
    373       GC(false);
    374     cache_.insert(make_pair(s, CacheState(weights, true)));
    375     if (cache_gc_)
    376       cache_size_ += weights->capacity() * sizeof(double);
    377   }
    378 
    379   int RefCount() const { return ref_count_.count(); }
    380   int IncrRefCount() { return ref_count_.Incr(); }
    381   int DecrRefCount() { return ref_count_.Decr(); }
    382 
    383  private:
    384   // Cached information for a given state.
    385   struct CacheState {
    386     vector<double>* weights;  // Accumulated weights for this state.
    387     bool recent;              // Has this state been accessed since last GC?
    388 
    389     CacheState(vector<double> *w, bool r) : weights(w), recent(r) {}
    390   };
    391 
    392   // Garbage collect: Delete from cache states that have not been
    393   // accessed since the last GC ('free_recent = false') until
    394   // 'cache_size_' is 2/3 of 'cache_limit_'. If it does not free enough
    395   // memory, start deleting recently accessed states.
    396   void GC(bool free_recent) {
    397     size_t cache_target = (2 * cache_limit_)/3 + 1;
    398     typename unordered_map<StateId, CacheState>::iterator it = cache_.begin();
    399     while (it != cache_.end() && cache_size_ > cache_target) {
    400       CacheState &cs = it->second;
    401       if (free_recent || !cs.recent) {
    402         cache_size_ -= cs.weights->capacity() * sizeof(double);
    403         delete cs.weights;
    404         cache_.erase(it++);
    405       } else {
    406         cs.recent = false;
    407         ++it;
    408       }
    409     }
    410     if (!free_recent && cache_size_ > cache_target)
    411       GC(true);
    412   }
    413 
    414   unordered_map<StateId, CacheState> cache_;  // Cache
    415   bool cache_gc_;                        // Enable garbage collection
    416   size_t cache_limit_;                   // # of bytes cached
    417   size_t cache_size_;                    // # of bytes allowed before GC
    418   RefCounter ref_count_;
    419 
    420   DISALLOW_COPY_AND_ASSIGN(CacheLogAccumulatorData);
    421 };
    422 
    423 // This class accumulates arc weights using the log semiring Plus()
    424 //  has a WeightConvert specialization to and from log64 weights.  It
    425 //  is similar to the FastLogAccumator. However here, the accumulated
    426 //  weights are pre-computed and stored only for the states that are
    427 //  visited. The member function Init(fst) has to be called to setup
    428 //  this accumulator.
    429 template <class A>
    430 class CacheLogAccumulator {
    431  public:
    432   typedef A Arc;
    433   typedef typename A::StateId StateId;
    434   typedef typename A::Weight Weight;
    435 
    436   explicit CacheLogAccumulator(ssize_t arc_limit = 10, bool gc = false,
    437                                size_t gc_limit = 10 * 1024 * 1024)
    438       : arc_limit_(arc_limit), fst_(0), data_(
    439           new CacheLogAccumulatorData<A>(gc, gc_limit)), s_(kNoStateId),
    440         error_(false) {}
    441 
    442   CacheLogAccumulator(const CacheLogAccumulator<A> &acc)
    443       : arc_limit_(acc.arc_limit_), fst_(acc.fst_ ? acc.fst_->Copy() : 0),
    444         data_(acc.data_), s_(kNoStateId), error_(acc.error_) {
    445     data_->IncrRefCount();
    446   }
    447 
    448   ~CacheLogAccumulator() {
    449     if (fst_)
    450       delete fst_;
    451     if (!data_->DecrRefCount())
    452       delete data_;
    453   }
    454 
    455   // Arg 'arc_limit' specifies minimum # of arcs to pre-compute state.
    456   void Init(const Fst<A> &fst, bool copy = false) {
    457     if (copy) {
    458       delete fst_;
    459     } else if (fst_) {
    460       FSTERROR() << "CacheLogAccumulator: initialization error.";
    461       error_ = true;
    462       return;
    463     }
    464     fst_ = fst.Copy();
    465   }
    466 
    467   void SetState(StateId s, int depth = 0) {
    468     if (s == s_)
    469       return;
    470     s_ = s;
    471 
    472     if (data_->CacheDisabled() || error_) {
    473       weights_ = 0;
    474       return;
    475     }
    476 
    477     if (!fst_) {
    478       FSTERROR() << "CacheLogAccumulator::SetState: incorrectly initialized.";
    479       error_ = true;
    480       weights_ = 0;
    481       return;
    482     }
    483 
    484     weights_ = data_->GetWeights(s);
    485     if ((weights_ == 0) && (fst_->NumArcs(s) >= arc_limit_)) {
    486       weights_ = new vector<double>;
    487       weights_->reserve(fst_->NumArcs(s) + 1);
    488       weights_->push_back(FloatLimits<double>::PosInfinity());
    489       data_->AddWeights(s, weights_);
    490     }
    491   }
    492 
    493   Weight Sum(Weight w, Weight v) {
    494     return LogPlus(w, v);
    495   }
    496 
    497   template <class Iterator>
    498   Weight Sum(Weight w, Iterator *aiter, ssize_t begin,
    499              ssize_t end) {
    500     if (weights_ == 0) {
    501       Weight sum = w;
    502       aiter->Seek(begin);
    503       for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos)
    504         sum = LogPlus(sum, aiter->Value().weight);
    505       return sum;
    506     } else {
    507       if (weights_->size() <= end)
    508         for (aiter->Seek(weights_->size() - 1);
    509              weights_->size() <= end;
    510              aiter->Next())
    511           weights_->push_back(LogPlus(weights_->back(),
    512                                       aiter->Value().weight));
    513       return LogPlus(w, LogMinus((*weights_)[end], (*weights_)[begin]));
    514     }
    515   }
    516 
    517   template <class Iterator>
    518   size_t LowerBound(double w, Iterator *aiter) {
    519     if (weights_ != 0) {
    520       return lower_bound(weights_->begin() + 1,
    521                          weights_->end(),
    522                          w,
    523                          std::greater<double>())
    524           - weights_->begin() - 1;
    525     } else {
    526       size_t n = 0;
    527       double x =  FloatLimits<double>::PosInfinity();
    528       for(aiter->Reset(); !aiter->Done(); aiter->Next(), ++n) {
    529         x = LogPlus(x, aiter->Value().weight);
    530         if (x < w) break;
    531       }
    532       return n;
    533     }
    534   }
    535 
    536   bool Error() const { return error_; }
    537 
    538  private:
    539   double LogPosExp(double x) {
    540     return x == FloatLimits<double>::PosInfinity() ?
    541         0.0 : log(1.0F + exp(-x));
    542   }
    543 
    544   double LogMinusExp(double x) {
    545     return x == FloatLimits<double>::PosInfinity() ?
    546         0.0 : log(1.0F - exp(-x));
    547   }
    548 
    549   Weight LogPlus(Weight w, Weight v) {
    550     double f1 = to_log_weight_(w).Value();
    551     double f2 = to_log_weight_(v).Value();
    552     if (f1 > f2)
    553       return to_weight_(f2 - LogPosExp(f1 - f2));
    554     else
    555       return to_weight_(f1 - LogPosExp(f2 - f1));
    556   }
    557 
    558   double LogPlus(double f1, Weight v) {
    559     double f2 = to_log_weight_(v).Value();
    560     if (f1 == FloatLimits<double>::PosInfinity())
    561       return f2;
    562     else if (f1 > f2)
    563       return f2 - LogPosExp(f1 - f2);
    564     else
    565       return f1 - LogPosExp(f2 - f1);
    566   }
    567 
    568   Weight LogMinus(double f1, double f2) {
    569     if (f1 >= f2) {
    570       FSTERROR() << "CacheLogAcumulator::LogMinus: f1 >= f2 with f1 = " << f1
    571                  << " and f2 = " << f2;
    572       error_ = true;
    573       return Weight::NoWeight();
    574     }
    575     if (f2 == FloatLimits<double>::PosInfinity())
    576       return to_weight_(f1);
    577     else
    578       return to_weight_(f1 - LogMinusExp(f2 - f1));
    579   }
    580 
    581   WeightConvert<Weight, Log64Weight> to_log_weight_;
    582   WeightConvert<Log64Weight, Weight> to_weight_;
    583 
    584   ssize_t arc_limit_;                    // Minimum # of arcs to cache a state
    585   vector<double> *weights_;              // Accumulated weights for cur. state
    586   const Fst<A>* fst_;                    // Input fst
    587   CacheLogAccumulatorData<A> *data_;     // Cache data
    588   StateId s_;                            // Current state
    589   bool error_;
    590 
    591   void operator=(const CacheLogAccumulator<A> &);   // Disallow
    592 };
    593 
    594 
    595 // Stores shareable data for replace accumulator copies.
    596 template <class Accumulator, class T>
    597 class ReplaceAccumulatorData {
    598  public:
    599   typedef typename Accumulator::Arc Arc;
    600   typedef typename Arc::StateId StateId;
    601   typedef typename Arc::Label Label;
    602   typedef T StateTable;
    603   typedef typename T::StateTuple StateTuple;
    604 
    605   ReplaceAccumulatorData() : state_table_(0) {}
    606 
    607   ReplaceAccumulatorData(const vector<Accumulator*> &accumulators)
    608       : state_table_(0), accumulators_(accumulators) {}
    609 
    610   ~ReplaceAccumulatorData() {
    611     for (size_t i = 0; i < fst_array_.size(); ++i)
    612       delete fst_array_[i];
    613     for (size_t i = 0; i < accumulators_.size(); ++i)
    614       delete accumulators_[i];
    615   }
    616 
    617   void Init(const vector<pair<Label, const Fst<Arc>*> > &fst_tuples,
    618        const StateTable *state_table) {
    619     state_table_ = state_table;
    620     accumulators_.resize(fst_tuples.size());
    621     for (size_t i = 0; i < accumulators_.size(); ++i) {
    622       if (!accumulators_[i])
    623         accumulators_[i] = new Accumulator;
    624       accumulators_[i]->Init(*(fst_tuples[i].second));
    625       fst_array_.push_back(fst_tuples[i].second->Copy());
    626     }
    627   }
    628 
    629   const StateTuple &GetTuple(StateId s) const {
    630     return state_table_->Tuple(s);
    631   }
    632 
    633   Accumulator *GetAccumulator(size_t i) { return accumulators_[i]; }
    634 
    635   const Fst<Arc> *GetFst(size_t i) const { return fst_array_[i]; }
    636 
    637   int RefCount() const { return ref_count_.count(); }
    638   int IncrRefCount() { return ref_count_.Incr(); }
    639   int DecrRefCount() { return ref_count_.Decr(); }
    640 
    641  private:
    642   const T * state_table_;
    643   vector<Accumulator*> accumulators_;
    644   vector<const Fst<Arc>*> fst_array_;
    645   RefCounter ref_count_;
    646 
    647   DISALLOW_COPY_AND_ASSIGN(ReplaceAccumulatorData);
    648 };
    649 
    650 // This class accumulates weights in a ReplaceFst.  The 'Init' method
    651 // takes as input the argument used to build the ReplaceFst and the
    652 // ReplaceFst state table. It uses accumulators of type 'Accumulator'
    653 // in the underlying FSTs.
    654 template <class Accumulator,
    655           class T = DefaultReplaceStateTable<typename Accumulator::Arc> >
    656 class ReplaceAccumulator {
    657  public:
    658   typedef typename Accumulator::Arc Arc;
    659   typedef typename Arc::StateId StateId;
    660   typedef typename Arc::Label Label;
    661   typedef typename Arc::Weight Weight;
    662   typedef T StateTable;
    663   typedef typename T::StateTuple StateTuple;
    664 
    665   ReplaceAccumulator()
    666       : init_(false), data_(new ReplaceAccumulatorData<Accumulator, T>()),
    667         error_(false) {}
    668 
    669   ReplaceAccumulator(const vector<Accumulator*> &accumulators)
    670       : init_(false),
    671         data_(new ReplaceAccumulatorData<Accumulator, T>(accumulators)),
    672         error_(false) {}
    673 
    674   ReplaceAccumulator(const ReplaceAccumulator<Accumulator, T> &acc)
    675       : init_(acc.init_), data_(acc.data_), error_(acc.error_) {
    676     if (!init_)
    677       FSTERROR() << "ReplaceAccumulator: can't copy unintialized accumulator";
    678     data_->IncrRefCount();
    679   }
    680 
    681   ~ReplaceAccumulator() {
    682      if (!data_->DecrRefCount())
    683       delete data_;
    684   }
    685 
    686   // Does not take ownership of the state table, the state table
    687   // is own by the ReplaceFst
    688   void Init(const vector<pair<Label, const Fst<Arc>*> > &fst_tuples,
    689             const StateTable *state_table) {
    690     init_ = true;
    691     data_->Init(fst_tuples, state_table);
    692   }
    693 
    694   void SetState(StateId s) {
    695     if (!init_) {
    696       FSTERROR() << "ReplaceAccumulator::SetState: incorrectly initialized.";
    697       error_ = true;
    698       return;
    699     }
    700     StateTuple tuple = data_->GetTuple(s);
    701     fst_id_ = tuple.fst_id - 1;  // Replace FST ID is 1-based
    702     data_->GetAccumulator(fst_id_)->SetState(tuple.fst_state);
    703     if ((tuple.prefix_id != 0) &&
    704         (data_->GetFst(fst_id_)->Final(tuple.fst_state) != Weight::Zero())) {
    705       offset_ = 1;
    706       offset_weight_ = data_->GetFst(fst_id_)->Final(tuple.fst_state);
    707     } else {
    708       offset_ = 0;
    709       offset_weight_ = Weight::Zero();
    710     }
    711   }
    712 
    713   Weight Sum(Weight w, Weight v) {
    714     if (error_) return Weight::NoWeight();
    715     return data_->GetAccumulator(fst_id_)->Sum(w, v);
    716   }
    717 
    718   template <class ArcIterator>
    719   Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin,
    720              ssize_t end) {
    721     if (error_) return Weight::NoWeight();
    722     Weight sum = begin == end ? Weight::Zero()
    723         : data_->GetAccumulator(fst_id_)->Sum(
    724             w, aiter, begin ? begin - offset_ : 0, end - offset_);
    725     if (begin == 0 && end != 0 && offset_ > 0)
    726       sum = Sum(offset_weight_, sum);
    727     return sum;
    728   }
    729 
    730   bool Error() const { return error_; }
    731 
    732  private:
    733   bool init_;
    734   ReplaceAccumulatorData<Accumulator, T> *data_;
    735   Label fst_id_;
    736   size_t offset_;
    737   Weight offset_weight_;
    738   bool error_;
    739 
    740   void operator=(const ReplaceAccumulator<Accumulator, T> &);   // Disallow
    741 };
    742 
    743 }  // namespace fst
    744 
    745 #endif  // FST_LIB_ACCUMULATOR_H__
    746