Home | History | Annotate | Download | only in fst
      1 // rmepsilon.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: allauzen (at) google.com (Cyril Allauzen)
     17 //
     18 // \file
// Functions and classes that implement epsilon-removal.
     20 
     21 #ifndef FST_LIB_RMEPSILON_H__
     22 #define FST_LIB_RMEPSILON_H__
     23 
     24 #include <tr1/unordered_map>
     25 using std::tr1::unordered_map;
     26 using std::tr1::unordered_multimap;
     27 #include <fst/slist.h>
     28 #include <stack>
     29 #include <string>
     30 #include <utility>
     31 using std::pair; using std::make_pair;
     32 #include <vector>
     33 using std::vector;
     34 
     35 #include <fst/arcfilter.h>
     36 #include <fst/cache.h>
     37 #include <fst/connect.h>
     38 #include <fst/factor-weight.h>
     39 #include <fst/invert.h>
     40 #include <fst/prune.h>
     41 #include <fst/queue.h>
     42 #include <fst/shortest-distance.h>
     43 #include <fst/topsort.h>
     44 
     45 
     46 namespace fst {
     47 
     48 template <class Arc, class Queue>
     49 class RmEpsilonOptions
     50     : public ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc> > {
     51  public:
     52   typedef typename Arc::StateId StateId;
     53   typedef typename Arc::Weight Weight;
     54 
     55   bool connect;              // Connect output
     56   Weight weight_threshold;   // Pruning weight threshold.
     57   StateId state_threshold;   // Pruning state threshold.
     58 
     59   explicit RmEpsilonOptions(Queue *q, float d = kDelta, bool c = true,
     60                             Weight w = Weight::Zero(),
     61                             StateId n = kNoStateId)
     62       : ShortestDistanceOptions< Arc, Queue, EpsilonArcFilter<Arc> >(
     63           q, EpsilonArcFilter<Arc>(), kNoStateId, d),
     64         connect(c), weight_threshold(w), state_threshold(n) {}
     65  private:
     66   RmEpsilonOptions();  // disallow
     67 };
     68 
     69 // Computation state of the epsilon-removal algorithm.
     70 template <class Arc, class Queue>
     71 class RmEpsilonState {
     72  public:
     73   typedef typename Arc::Label Label;
     74   typedef typename Arc::StateId StateId;
     75   typedef typename Arc::Weight Weight;
     76 
     77   RmEpsilonState(const Fst<Arc> &fst,
     78                  vector<Weight> *distance,
     79                  const RmEpsilonOptions<Arc, Queue> &opts)
     80       : fst_(fst), distance_(distance), sd_state_(fst_, distance, opts, true),
     81         expand_id_(0) {}
     82 
     83   // Compute arcs and final weight for state 's'
     84   void Expand(StateId s);
     85 
     86   // Returns arcs of expanded state.
     87   vector<Arc> &Arcs() { return arcs_; }
     88 
     89   // Returns final weight of expanded state.
     90   const Weight &Final() const { return final_; }
     91 
     92   // Return true if an error has occured.
     93   bool Error() const { return sd_state_.Error(); }
     94 
     95  private:
     96   static const size_t kPrime0 = 7853;
     97   static const size_t kPrime1 = 7867;
     98 
     99   struct Element {
    100     Label ilabel;
    101     Label olabel;
    102     StateId nextstate;
    103 
    104     Element() {}
    105 
    106     Element(Label i, Label o, StateId s)
    107         : ilabel(i), olabel(o), nextstate(s) {}
    108   };
    109 
    110   class ElementKey {
    111    public:
    112     size_t operator()(const Element& e) const {
    113       return static_cast<size_t>(e.nextstate +
    114                                  e.ilabel * kPrime0 +
    115                                  e.olabel * kPrime1);
    116     }
    117 
    118    private:
    119   };
    120 
    121   class ElementEqual {
    122    public:
    123     bool operator()(const Element &e1, const Element &e2) const {
    124       return (e1.ilabel == e2.ilabel) &&  (e1.olabel == e2.olabel)
    125                          && (e1.nextstate == e2.nextstate);
    126     }
    127   };
    128 
    129   typedef unordered_map<Element, pair<StateId, size_t>,
    130                    ElementKey, ElementEqual> ElementMap;
    131 
    132   const Fst<Arc> &fst_;
    133   // Distance from state being expanded in epsilon-closure.
    134   vector<Weight> *distance_;
    135   // Shortest distance algorithm computation state.
    136   ShortestDistanceState<Arc, Queue, EpsilonArcFilter<Arc> > sd_state_;
    137   // Maps an element 'e' to a pair 'p' corresponding to a position
    138   // in the arcs vector of the state being expanded. 'e' corresponds
    139   // to the position 'p.second' in the 'arcs_' vector if 'p.first' is
    140   // equal to the state being expanded.
    141   ElementMap element_map_;
    142   EpsilonArcFilter<Arc> eps_filter_;
    143   stack<StateId> eps_queue_;      // Queue used to visit the epsilon-closure
    144   vector<bool> visited_;          // '[i] = true' if state 'i' has been visited
    145   slist<StateId> visited_states_; // List of visited states
    146   vector<Arc> arcs_;              // Arcs of state being expanded
    147   Weight final_;                  // Final weight of state being expanded
    148   StateId expand_id_;             // Unique ID for each call to Expand
    149 
    150   DISALLOW_COPY_AND_ASSIGN(RmEpsilonState);
    151 };
    152 
// Namespace-scope definitions of the two hash multipliers that are declared
// and initialized inside RmEpsilonState. They provide the constants with a
// definition should they ever be odr-used.
template <class Arc, class Queue>
const size_t RmEpsilonState<Arc, Queue>::kPrime0;
template <class Arc, class Queue>
const size_t RmEpsilonState<Arc, Queue>::kPrime1;
    157 
    158 
    159 template <class Arc, class Queue>
    160 void RmEpsilonState<Arc,Queue>::Expand(typename Arc::StateId source) {
    161    final_ = Weight::Zero();
    162    arcs_.clear();
    163    sd_state_.ShortestDistance(source);
    164    if (sd_state_.Error())
    165      return;
    166    eps_queue_.push(source);
    167 
    168    while (!eps_queue_.empty()) {
    169      StateId state = eps_queue_.top();
    170      eps_queue_.pop();
    171 
    172      while (visited_.size() <= state) visited_.push_back(false);
    173      if (visited_[state]) continue;
    174      visited_[state] = true;
    175      visited_states_.push_front(state);
    176 
    177      for (ArcIterator< Fst<Arc> > ait(fst_, state);
    178           !ait.Done();
    179           ait.Next()) {
    180        Arc arc = ait.Value();
    181        arc.weight = Times((*distance_)[state], arc.weight);
    182 
    183        if (eps_filter_(arc)) {
    184          while (visited_.size() <= arc.nextstate)
    185            visited_.push_back(false);
    186          if (!visited_[arc.nextstate])
    187            eps_queue_.push(arc.nextstate);
    188        } else {
    189           Element element(arc.ilabel, arc.olabel, arc.nextstate);
    190           typename ElementMap::iterator it = element_map_.find(element);
    191           if (it == element_map_.end()) {
    192             element_map_.insert(
    193                 pair<Element, pair<StateId, size_t> >
    194                 (element, pair<StateId, size_t>(expand_id_, arcs_.size())));
    195             arcs_.push_back(arc);
    196           } else {
    197             if (((*it).second).first == expand_id_) {
    198               Weight &w = arcs_[((*it).second).second].weight;
    199               w = Plus(w, arc.weight);
    200             } else {
    201               ((*it).second).first = expand_id_;
    202               ((*it).second).second = arcs_.size();
    203               arcs_.push_back(arc);
    204             }
    205           }
    206         }
    207      }
    208      final_ = Plus(final_, Times((*distance_)[state], fst_.Final(state)));
    209    }
    210 
    211    while (!visited_states_.empty()) {
    212      visited_[visited_states_.front()] = false;
    213      visited_states_.pop_front();
    214    }
    215    ++expand_id_;
    216 }
    217 
    218 // Removes epsilon-transitions (when both the input and output label
    219 // are an epsilon) from a transducer. The result will be an equivalent
    220 // FST that has no such epsilon transitions.  This version modifies
    221 // its input. It allows fine control via the options argument; see
    222 // below for a simpler interface.
    223 //
    224 // The vector 'distance' will be used to hold the shortest distances
    225 // during the epsilon-closure computation. The state queue discipline
    226 // and convergence delta are taken in the options argument.
    227 template <class Arc, class Queue>
    228 void RmEpsilon(MutableFst<Arc> *fst,
    229                vector<typename Arc::Weight> *distance,
    230                const RmEpsilonOptions<Arc, Queue> &opts) {
    231   typedef typename Arc::StateId StateId;
    232   typedef typename Arc::Weight Weight;
    233   typedef typename Arc::Label Label;
    234 
    235   if (fst->Start() == kNoStateId) {
    236     return;
    237   }
    238 
    239   // 'noneps_in[s]' will be set to true iff 's' admits a non-epsilon
    240   // incoming transition or is the start state.
    241   vector<bool> noneps_in(fst->NumStates(), false);
    242   noneps_in[fst->Start()] = true;
    243   for (StateId i = 0; i < fst->NumStates(); ++i) {
    244     for (ArcIterator<Fst<Arc> > aiter(*fst, i);
    245          !aiter.Done();
    246          aiter.Next()) {
    247       if (aiter.Value().ilabel != 0 || aiter.Value().olabel != 0)
    248         noneps_in[aiter.Value().nextstate] = true;
    249     }
    250   }
    251 
    252   // States sorted in topological order when (acyclic) or generic
    253   // topological order (cyclic).
    254   vector<StateId> states;
    255   states.reserve(fst->NumStates());
    256 
    257   if (fst->Properties(kTopSorted, false) & kTopSorted) {
    258     for (StateId i = 0; i < fst->NumStates(); i++)
    259       states.push_back(i);
    260   } else if (fst->Properties(kAcyclic, false) & kAcyclic) {
    261     vector<StateId> order;
    262     bool acyclic;
    263     TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic);
    264     DfsVisit(*fst, &top_order_visitor, EpsilonArcFilter<Arc>());
    265     // Sanity check: should be acyclic if property bit is set.
    266     if(!acyclic) {
    267       FSTERROR() << "RmEpsilon: inconsistent acyclic property bit";
    268       fst->SetProperties(kError, kError);
    269       return;
    270     }
    271     states.resize(order.size());
    272     for (StateId i = 0; i < order.size(); i++)
    273       states[order[i]] = i;
    274   } else {
    275      uint64 props;
    276      vector<StateId> scc;
    277      SccVisitor<Arc> scc_visitor(&scc, 0, 0, &props);
    278      DfsVisit(*fst, &scc_visitor, EpsilonArcFilter<Arc>());
    279      vector<StateId> first(scc.size(), kNoStateId);
    280      vector<StateId> next(scc.size(), kNoStateId);
    281      for (StateId i = 0; i < scc.size(); i++) {
    282        if (first[scc[i]] != kNoStateId)
    283          next[i] = first[scc[i]];
    284        first[scc[i]] = i;
    285      }
    286      for (StateId i = 0; i < first.size(); i++)
    287        for (StateId j = first[i]; j != kNoStateId; j = next[j])
    288          states.push_back(j);
    289   }
    290 
    291   RmEpsilonState<Arc, Queue>
    292     rmeps_state(*fst, distance, opts);
    293 
    294   while (!states.empty()) {
    295     StateId state = states.back();
    296     states.pop_back();
    297     if (!noneps_in[state])
    298       continue;
    299     rmeps_state.Expand(state);
    300     fst->SetFinal(state, rmeps_state.Final());
    301     fst->DeleteArcs(state);
    302     vector<Arc> &arcs = rmeps_state.Arcs();
    303     fst->ReserveArcs(state, arcs.size());
    304     while (!arcs.empty()) {
    305       fst->AddArc(state, arcs.back());
    306       arcs.pop_back();
    307     }
    308   }
    309 
    310   for (StateId s = 0; s < fst->NumStates(); ++s) {
    311     if (!noneps_in[s])
    312       fst->DeleteArcs(s);
    313   }
    314 
    315   if(rmeps_state.Error())
    316     fst->SetProperties(kError, kError);
    317   fst->SetProperties(
    318       RmEpsilonProperties(fst->Properties(kFstProperties, false)),
    319       kFstProperties);
    320 
    321   if (opts.weight_threshold != Weight::Zero() ||
    322       opts.state_threshold != kNoStateId)
    323     Prune(fst, opts.weight_threshold, opts.state_threshold);
    324   if (opts.connect && (opts.weight_threshold == Weight::Zero() ||
    325                        opts.state_threshold != kNoStateId))
    326     Connect(fst);
    327 }
    328 
    329 // Removes epsilon-transitions (when both the input and output label
    330 // are an epsilon) from a transducer. The result will be an equivalent
    331 // FST that has no such epsilon transitions. This version modifies its
    332 // input. It has a simplified interface; see above for a version that
    333 // allows finer control.
    334 //
    335 // Complexity:
    336 // - Time:
    337 //   - Unweighted: O(V2 + V E)
    338 //   - Acyclic: O(V2 + V E)
    339 //   - Tropical semiring: O(V2 log V + V E)
    340 //   - General: exponential
    341 // - Space: O(V E)
    342 // where V = # of states visited, E = # of arcs.
    343 //
    344 // References:
    345 // - Mehryar Mohri. Generic Epsilon-Removal and Input
    346 //   Epsilon-Normalization Algorithms for Weighted Transducers,
    347 //   "International Journal of Computer Science", 13(1):129-143 (2002).
    348 template <class Arc>
    349 void RmEpsilon(MutableFst<Arc> *fst,
    350                bool connect = true,
    351                typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
    352                typename Arc::StateId state_threshold = kNoStateId,
    353                float delta = kDelta) {
    354   typedef typename Arc::StateId StateId;
    355   typedef typename Arc::Weight Weight;
    356   typedef typename Arc::Label Label;
    357 
    358   vector<Weight> distance;
    359   AutoQueue<StateId> state_queue(*fst, &distance, EpsilonArcFilter<Arc>());
    360   RmEpsilonOptions<Arc, AutoQueue<StateId> >
    361       opts(&state_queue, delta, connect, weight_threshold, state_threshold);
    362 
    363   RmEpsilon(fst, &distance, opts);
    364 }
    365 
    366 
    367 struct RmEpsilonFstOptions : CacheOptions {
    368   float delta;
    369 
    370   RmEpsilonFstOptions(const CacheOptions &opts, float delta = kDelta)
    371       : CacheOptions(opts), delta(delta) {}
    372 
    373   explicit RmEpsilonFstOptions(float delta = kDelta) : delta(delta) {}
    374 };
    375 
    376 
    377 // Implementation of delayed RmEpsilonFst.
    378 template <class A>
    379 class RmEpsilonFstImpl : public CacheImpl<A> {
    380  public:
    381   using FstImpl<A>::SetType;
    382   using FstImpl<A>::SetProperties;
    383   using FstImpl<A>::SetInputSymbols;
    384   using FstImpl<A>::SetOutputSymbols;
    385 
    386   using CacheBaseImpl< CacheState<A> >::PushArc;
    387   using CacheBaseImpl< CacheState<A> >::HasArcs;
    388   using CacheBaseImpl< CacheState<A> >::HasFinal;
    389   using CacheBaseImpl< CacheState<A> >::HasStart;
    390   using CacheBaseImpl< CacheState<A> >::SetArcs;
    391   using CacheBaseImpl< CacheState<A> >::SetFinal;
    392   using CacheBaseImpl< CacheState<A> >::SetStart;
    393 
    394   typedef typename A::Label Label;
    395   typedef typename A::Weight Weight;
    396   typedef typename A::StateId StateId;
    397   typedef CacheState<A> State;
    398 
    399   RmEpsilonFstImpl(const Fst<A>& fst, const RmEpsilonFstOptions &opts)
    400       : CacheImpl<A>(opts),
    401         fst_(fst.Copy()),
    402         delta_(opts.delta),
    403         rmeps_state_(
    404             *fst_,
    405             &distance_,
    406             RmEpsilonOptions<A, FifoQueue<StateId> >(&queue_, delta_, false)) {
    407     SetType("rmepsilon");
    408     uint64 props = fst.Properties(kFstProperties, false);
    409     SetProperties(RmEpsilonProperties(props, true), kCopyProperties);
    410     SetInputSymbols(fst.InputSymbols());
    411     SetOutputSymbols(fst.OutputSymbols());
    412   }
    413 
    414   RmEpsilonFstImpl(const RmEpsilonFstImpl &impl)
    415       : CacheImpl<A>(impl),
    416         fst_(impl.fst_->Copy(true)),
    417         delta_(impl.delta_),
    418         rmeps_state_(
    419             *fst_,
    420             &distance_,
    421             RmEpsilonOptions<A, FifoQueue<StateId> >(&queue_, delta_, false)) {
    422     SetType("rmepsilon");
    423     SetProperties(impl.Properties(), kCopyProperties);
    424     SetInputSymbols(impl.InputSymbols());
    425     SetOutputSymbols(impl.OutputSymbols());
    426   }
    427 
    428   ~RmEpsilonFstImpl() {
    429     delete fst_;
    430   }
    431 
    432   StateId Start() {
    433     if (!HasStart()) {
    434       SetStart(fst_->Start());
    435     }
    436     return CacheImpl<A>::Start();
    437   }
    438 
    439   Weight Final(StateId s) {
    440     if (!HasFinal(s)) {
    441       Expand(s);
    442     }
    443     return CacheImpl<A>::Final(s);
    444   }
    445 
    446   size_t NumArcs(StateId s) {
    447     if (!HasArcs(s))
    448       Expand(s);
    449     return CacheImpl<A>::NumArcs(s);
    450   }
    451 
    452   size_t NumInputEpsilons(StateId s) {
    453     if (!HasArcs(s))
    454       Expand(s);
    455     return CacheImpl<A>::NumInputEpsilons(s);
    456   }
    457 
    458   size_t NumOutputEpsilons(StateId s) {
    459     if (!HasArcs(s))
    460       Expand(s);
    461     return CacheImpl<A>::NumOutputEpsilons(s);
    462   }
    463 
    464   uint64 Properties() const { return Properties(kFstProperties); }
    465 
    466   // Set error if found; return FST impl properties.
    467   uint64 Properties(uint64 mask) const {
    468     if ((mask & kError) &&
    469         (fst_->Properties(kError, false) || rmeps_state_.Error()))
    470       SetProperties(kError, kError);
    471     return FstImpl<A>::Properties(mask);
    472   }
    473 
    474   void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
    475     if (!HasArcs(s))
    476       Expand(s);
    477     CacheImpl<A>::InitArcIterator(s, data);
    478   }
    479 
    480   void Expand(StateId s) {
    481     rmeps_state_.Expand(s);
    482     SetFinal(s, rmeps_state_.Final());
    483     vector<A> &arcs = rmeps_state_.Arcs();
    484     while (!arcs.empty()) {
    485       PushArc(s, arcs.back());
    486       arcs.pop_back();
    487     }
    488     SetArcs(s);
    489   }
    490 
    491  private:
    492   const Fst<A> *fst_;
    493   float delta_;
    494   vector<Weight> distance_;
    495   FifoQueue<StateId> queue_;
    496   RmEpsilonState<A, FifoQueue<StateId> > rmeps_state_;
    497 
    498   void operator=(const RmEpsilonFstImpl<A> &);  // disallow
    499 };
    500 
    501 
    502 // Removes epsilon-transitions (when both the input and output label
    503 // are an epsilon) from a transducer. The result will be an equivalent
    504 // FST that has no such epsilon transitions.  This version is a
    505 // delayed Fst.
    506 //
    507 // Complexity:
    508 // - Time:
    509 //   - Unweighted: O(v^2 + v e)
    510 //   - General: exponential
    511 // - Space: O(v e)
    512 // where v = # of states visited, e = # of arcs visited. Constant time
    513 // to visit an input state or arc is assumed and exclusive of caching.
    514 //
    515 // References:
    516 // - Mehryar Mohri. Generic Epsilon-Removal and Input
    517 //   Epsilon-Normalization Algorithms for Weighted Transducers,
    518 //   "International Journal of Computer Science", 13(1):129-143 (2002).
    519 //
    520 // This class attaches interface to implementation and handles
    521 // reference counting, delegating most methods to ImplToFst.
    522 template <class A>
    523 class RmEpsilonFst : public ImplToFst< RmEpsilonFstImpl<A> > {
    524  public:
    525   friend class ArcIterator< RmEpsilonFst<A> >;
    526   friend class StateIterator< RmEpsilonFst<A> >;
    527 
    528   typedef A Arc;
    529   typedef typename A::StateId StateId;
    530   typedef CacheState<A> State;
    531   typedef RmEpsilonFstImpl<A> Impl;
    532 
    533   RmEpsilonFst(const Fst<A> &fst)
    534       : ImplToFst<Impl>(new Impl(fst, RmEpsilonFstOptions())) {}
    535 
    536   RmEpsilonFst(const Fst<A> &fst, const RmEpsilonFstOptions &opts)
    537       : ImplToFst<Impl>(new Impl(fst, opts)) {}
    538 
    539   // See Fst<>::Copy() for doc.
    540   RmEpsilonFst(const RmEpsilonFst<A> &fst, bool safe = false)
    541       : ImplToFst<Impl>(fst, safe) {}
    542 
    543   // Get a copy of this RmEpsilonFst. See Fst<>::Copy() for further doc.
    544   virtual RmEpsilonFst<A> *Copy(bool safe = false) const {
    545     return new RmEpsilonFst<A>(*this, safe);
    546   }
    547 
    548   virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
    549 
    550   virtual void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const {
    551     GetImpl()->InitArcIterator(s, data);
    552   }
    553 
    554  private:
    555   // Makes visible to friends.
    556   Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
    557 
    558   void operator=(const RmEpsilonFst<A> &fst);  // disallow
    559 };
    560 
    561 // Specialization for RmEpsilonFst.
    562 template<class A>
    563 class StateIterator< RmEpsilonFst<A> >
    564     : public CacheStateIterator< RmEpsilonFst<A> > {
    565  public:
    566   explicit StateIterator(const RmEpsilonFst<A> &fst)
    567       : CacheStateIterator< RmEpsilonFst<A> >(fst, fst.GetImpl()) {}
    568 };
    569 
    570 
    571 // Specialization for RmEpsilonFst.
    572 template <class A>
    573 class ArcIterator< RmEpsilonFst<A> >
    574     : public CacheArcIterator< RmEpsilonFst<A> > {
    575  public:
    576   typedef typename A::StateId StateId;
    577 
    578   ArcIterator(const RmEpsilonFst<A> &fst, StateId s)
    579       : CacheArcIterator< RmEpsilonFst<A> >(fst.GetImpl(), s) {
    580     if (!fst.GetImpl()->HasArcs(s))
    581       fst.GetImpl()->Expand(s);
    582   }
    583 
    584  private:
    585   DISALLOW_COPY_AND_ASSIGN(ArcIterator);
    586 };
    587 
    588 
    589 template <class A> inline
    590 void RmEpsilonFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
    591   data->base = new StateIterator< RmEpsilonFst<A> >(*this);
    592 }
    593 
    594 
// Useful alias when using StdArc: the delayed epsilon-removal FST
// instantiated for StdArc.
typedef RmEpsilonFst<StdArc> StdRmEpsilonFst;
    597 
    598 }  // namespace fst
    599 
    600 #endif  // FST_LIB_RMEPSILON_H__
    601