Home | History | Annotate | Download | only in lib
      1 // rmepsilon.h
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //      http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Author: allauzen (at) cs.nyu.edu (Cyril Allauzen)
     16 //
     17 // \file
     18 // Functions and classes that implement epsilon-removal.
     19 
     20 #ifndef FST_LIB_RMEPSILON_H__
     21 #define FST_LIB_RMEPSILON_H__
     22 
     23 #include <unordered_map>
     24 #include <forward_list>
     25 
     26 #include "fst/lib/arcfilter.h"
     27 #include "fst/lib/cache.h"
     28 #include "fst/lib/connect.h"
     29 #include "fst/lib/factor-weight.h"
     30 #include "fst/lib/invert.h"
     31 #include "fst/lib/map.h"
     32 #include "fst/lib/queue.h"
     33 #include "fst/lib/shortest-distance.h"
     34 #include "fst/lib/topsort.h"
     35 
     36 namespace fst {
     37 
     38 template <class Arc, class Queue>
     39 struct RmEpsilonOptions
     40     : public ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc> > {
     41   typedef typename Arc::StateId StateId;
     42 
     43   bool connect;  // Connect output
     44 
     45   RmEpsilonOptions(Queue *q, float d = kDelta, bool c = true)
     46       : ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc> >(
     47           q, EpsilonArcFilter<Arc>(), kNoStateId, d), connect(c) {}
     48 
     49 };
     50 
     51 
     52 // Computation state of the epsilon-removal algorithm.
     53 template <class Arc, class Queue>
     54 class RmEpsilonState {
     55  public:
     56   typedef typename Arc::Label Label;
     57   typedef typename Arc::StateId StateId;
     58   typedef typename Arc::Weight Weight;
     59 
     60   RmEpsilonState(const Fst<Arc> &fst,
     61                  vector<Weight> *distance,
     62                  const RmEpsilonOptions<Arc, Queue> &opts)
     63       : fst_(fst), distance_(distance), sd_state_(fst_, distance, opts, true) {
     64   }
     65 
     66   // Compute arcs and final weight for state 's'
     67   void Expand(StateId s);
     68 
     69   // Returns arcs of expanded state.
     70   vector<Arc> &Arcs() { return arcs_; }
     71 
     72   // Returns final weight of expanded state.
     73   const Weight &Final() const { return final_; }
     74 
     75  private:
     76   struct Element {
     77     Label ilabel;
     78     Label olabel;
     79     StateId nextstate;
     80 
     81     Element() {}
     82 
     83     Element(Label i, Label o, StateId s)
     84         : ilabel(i), olabel(o), nextstate(s) {}
     85   };
     86 
     87   class ElementKey {
     88    public:
     89     size_t operator()(const Element& e) const {
     90       return static_cast<size_t>(e.nextstate);
     91       return static_cast<size_t>(e.nextstate +
     92                                  e.ilabel * kPrime0 +
     93                                  e.olabel * kPrime1);
     94     }
     95 
     96    private:
     97     static const int kPrime0 = 7853;
     98     static const int kPrime1 = 7867;
     99   };
    100 
    101   class ElementEqual {
    102    public:
    103     bool operator()(const Element &e1, const Element &e2) const {
    104       return (e1.ilabel == e2.ilabel) &&  (e1.olabel == e2.olabel)
    105                          && (e1.nextstate == e2.nextstate);
    106     }
    107   };
    108 
    109  private:
    110   typedef std::unordered_map<Element, pair<StateId, ssize_t>,
    111                              ElementKey, ElementEqual> ElementMap;
    112 
    113   const Fst<Arc> &fst_;
    114   // Distance from state being expanded in epsilon-closure.
    115   vector<Weight> *distance_;
    116   // Shortest distance algorithm computation state.
    117   ShortestDistanceState<Arc, Queue, EpsilonArcFilter<Arc> > sd_state_;
    118   // Maps an element 'e' to a pair 'p' corresponding to a position
    119   // in the arcs vector of the state being expanded. 'e' corresponds
    120   // to the position 'p.second' in the 'arcs_' vector if 'p.first' is
    121   // equal to the state being expanded.
    122   ElementMap element_map_;
    123   EpsilonArcFilter<Arc> eps_filter_;
    124   stack<StateId> eps_queue_;      // Queue used to visit the epsilon-closure
    125   vector<bool> visited_;          // '[i] = true' if state 'i' has been visited
    126   std::forward_list<StateId> visited_states_; // List of visited states
    127   vector<Arc> arcs_;              // Arcs of state being expanded
    128   Weight final_;                  // Final weight of state being expanded
    129 
    130   void operator=(const RmEpsilonState);  // Disallow
    131 };
    132 
    133 
    134 template <class Arc, class Queue>
    135 void RmEpsilonState<Arc,Queue>::Expand(typename Arc::StateId source) {
    136    sd_state_.ShortestDistance(source);
    137    eps_queue_.push(source);
    138    final_ = Weight::Zero();
    139    arcs_.clear();
    140 
    141    while (!eps_queue_.empty()) {
    142      StateId state = eps_queue_.top();
    143      eps_queue_.pop();
    144 
    145      while ((StateId)visited_.size() <= state) visited_.push_back(false);
    146      visited_[state] = true;
    147      visited_states_.push_front(state);
    148 
    149      for (ArcIterator< Fst<Arc> > ait(fst_, state);
    150           !ait.Done();
    151           ait.Next()) {
    152        Arc arc = ait.Value();
    153        arc.weight = Times((*distance_)[state], arc.weight);
    154 
    155        if (eps_filter_(arc)) {
    156          while ((StateId)visited_.size() <= arc.nextstate)
    157            visited_.push_back(false);
    158          if (!visited_[arc.nextstate])
    159            eps_queue_.push(arc.nextstate);
    160        } else {
    161           Element element(arc.ilabel, arc.olabel, arc.nextstate);
    162           typename ElementMap::iterator it = element_map_.find(element);
    163           if (it == element_map_.end()) {
    164             element_map_.insert(
    165                 pair<Element, pair<StateId, ssize_t> >
    166                 (element, pair<StateId, ssize_t>(source, arcs_.size())));
    167             arcs_.push_back(arc);
    168           } else {
    169             if (((*it).second).first == source) {
    170               Weight &w = arcs_[((*it).second).second].weight;
    171               w = Plus(w, arc.weight);
    172             } else {
    173               ((*it).second).first = source;
    174               ((*it).second).second = arcs_.size();
    175               arcs_.push_back(arc);
    176             }
    177           }
    178         }
    179      }
    180      final_ = Plus(final_, Times((*distance_)[state], fst_.Final(state)));
    181    }
    182 
    183    while (!visited_states_.empty()) {
    184      visited_[visited_states_.front()] = false;
    185      visited_states_.pop_front();
    186    }
    187 }
    188 
    189 
    190 // Removes epsilon-transitions (when both the input and output label
    191 // are an epsilon) from a transducer. The result will be an equivalent
    192 // FST that has no such epsilon transitions.  This version modifies
    193 // its input. It allows fine control via the options argument; see
    194 // below for a simpler interface.
    195 //
    196 // The vector 'distance' will be used to hold the shortest distances
    197 // during the epsilon-closure computation. The state queue discipline
    198 // and convergence delta are taken in the options argument.
    199 template <class Arc, class Queue>
    200 void RmEpsilon(MutableFst<Arc> *fst,
    201                vector<typename Arc::Weight> *distance,
    202                const RmEpsilonOptions<Arc, Queue> &opts) {
    203   typedef typename Arc::StateId StateId;
    204   typedef typename Arc::Weight Weight;
    205   typedef typename Arc::Label Label;
    206 
    207   // States sorted in topological order when (acyclic) or generic
    208   // topological order (cyclic).
    209   vector<StateId> states;
    210 
    211   if (fst->Properties(kTopSorted, false) & kTopSorted) {
    212     for (StateId i = 0; i < (StateId)fst->NumStates(); i++)
    213       states.push_back(i);
    214   } else if (fst->Properties(kAcyclic, false) & kAcyclic) {
    215     vector<StateId> order;
    216     bool acyclic;
    217     TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic);
    218     DfsVisit(*fst, &top_order_visitor, EpsilonArcFilter<Arc>());
    219     if (!acyclic)
    220       LOG(FATAL) << "RmEpsilon: not acyclic though property bit is set";
    221     states.resize(order.size());
    222     for (StateId i = 0; i < (StateId)order.size(); i++)
    223       states[order[i]] = i;
    224   } else {
    225      uint64 props;
    226      vector<StateId> scc;
    227      SccVisitor<Arc> scc_visitor(&scc, 0, 0, &props);
    228      DfsVisit(*fst, &scc_visitor, EpsilonArcFilter<Arc>());
    229      vector<StateId> first(scc.size(), kNoStateId);
    230      vector<StateId> next(scc.size(), kNoStateId);
    231      for (StateId i = 0; i < (StateId)scc.size(); i++) {
    232        if (first[scc[i]] != kNoStateId)
    233          next[i] = first[scc[i]];
    234        first[scc[i]] = i;
    235      }
    236      for (StateId i = 0; i < (StateId)first.size(); i++)
    237        for (StateId j = first[i]; j != kNoStateId; j = next[j])
    238          states.push_back(j);
    239   }
    240 
    241   RmEpsilonState<Arc, Queue>
    242     rmeps_state(*fst, distance, opts);
    243 
    244   while (!states.empty()) {
    245     StateId state = states.back();
    246     states.pop_back();
    247     rmeps_state.Expand(state);
    248     fst->SetFinal(state, rmeps_state.Final());
    249     fst->DeleteArcs(state);
    250     vector<Arc> &arcs = rmeps_state.Arcs();
    251     while (!arcs.empty()) {
    252       fst->AddArc(state, arcs.back());
    253       arcs.pop_back();
    254     }
    255   }
    256 
    257   fst->SetProperties(RmEpsilonProperties(
    258                          fst->Properties(kFstProperties, false)),
    259                      kFstProperties);
    260 
    261   if (opts.connect)
    262     Connect(fst);
    263 }
    264 
    265 
    266 // Removes epsilon-transitions (when both the input and output label
    267 // are an epsilon) from a transducer. The result will be an equivalent
    268 // FST that has no such epsilon transitions. This version modifies its
    269 // input. It has a simplified interface; see above for a version that
    270 // allows finer control.
    271 //
    272 // Complexity:
    273 // - Time:
    274 //   - Unweighted: O(V2 + V E)
    275 //   - Acyclic: O(V2 + V E)
    276 //   - Tropical semiring: O(V2 log V + V E)
    277 //   - General: exponential
    278 // - Space: O(V E)
    279 // where V = # of states visited, E = # of arcs.
    280 //
    281 // References:
    282 // - Mehryar Mohri. Generic Epsilon-Removal and Input
    283 //   Epsilon-Normalization Algorithms for Weighted Transducers,
    284 //   "International Journal of Computer Science", 13(1):129-143 (2002).
    285 template <class Arc>
    286 void RmEpsilon(MutableFst<Arc> *fst, bool connect = true) {
    287   typedef typename Arc::StateId StateId;
    288   typedef typename Arc::Weight Weight;
    289   typedef typename Arc::Label Label;
    290 
    291   vector<Weight> distance;
    292   AutoQueue<StateId> state_queue(*fst, &distance, EpsilonArcFilter<Arc>());
    293   RmEpsilonOptions<Arc, AutoQueue<StateId> >
    294     opts(&state_queue, kDelta, connect);
    295 
    296   RmEpsilon(fst, &distance, opts);
    297 }
    298 
    299 
    300 struct RmEpsilonFstOptions : CacheOptions {
    301   float delta;
    302 
    303   RmEpsilonFstOptions(const CacheOptions &opts, float delta = kDelta)
    304       : CacheOptions(opts), delta(delta) {}
    305 
    306   explicit RmEpsilonFstOptions(float delta = kDelta) : delta(delta) {}
    307 };
    308 
    309 
    310 // Implementation of delayed RmEpsilonFst.
    311 template <class A>
    312 class RmEpsilonFstImpl : public CacheImpl<A> {
    313  public:
    314   using FstImpl<A>::SetType;
    315   using FstImpl<A>::SetProperties;
    316   using FstImpl<A>::Properties;
    317   using FstImpl<A>::SetInputSymbols;
    318   using FstImpl<A>::SetOutputSymbols;
    319 
    320   using CacheBaseImpl< CacheState<A> >::HasStart;
    321   using CacheBaseImpl< CacheState<A> >::HasFinal;
    322   using CacheBaseImpl< CacheState<A> >::HasArcs;
    323 
    324   typedef typename A::Label Label;
    325   typedef typename A::Weight Weight;
    326   typedef typename A::StateId StateId;
    327   typedef CacheState<A> State;
    328 
    329   RmEpsilonFstImpl(const Fst<A>& fst, const RmEpsilonFstOptions &opts)
    330       : CacheImpl<A>(opts),
    331         fst_(fst.Copy()),
    332         rmeps_state_(
    333             *fst_,
    334             &distance_,
    335             RmEpsilonOptions<A, FifoQueue<StateId> >(&queue_, opts.delta,
    336                                                      false)
    337             ) {
    338     SetType("rmepsilon");
    339     uint64 props = fst.Properties(kFstProperties, false);
    340     SetProperties(RmEpsilonProperties(props, true), kCopyProperties);
    341   }
    342 
    343   ~RmEpsilonFstImpl() {
    344     delete fst_;
    345   }
    346 
    347   StateId Start() {
    348     if (!HasStart()) {
    349       SetStart(fst_->Start());
    350     }
    351     return CacheImpl<A>::Start();
    352   }
    353 
    354   Weight Final(StateId s) {
    355     if (!HasFinal(s)) {
    356       Expand(s);
    357     }
    358     return CacheImpl<A>::Final(s);
    359   }
    360 
    361   size_t NumArcs(StateId s) {
    362     if (!HasArcs(s))
    363       Expand(s);
    364     return CacheImpl<A>::NumArcs(s);
    365   }
    366 
    367   size_t NumInputEpsilons(StateId s) {
    368     if (!HasArcs(s))
    369       Expand(s);
    370     return CacheImpl<A>::NumInputEpsilons(s);
    371   }
    372 
    373   size_t NumOutputEpsilons(StateId s) {
    374     if (!HasArcs(s))
    375       Expand(s);
    376     return CacheImpl<A>::NumOutputEpsilons(s);
    377   }
    378 
    379   void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
    380     if (!HasArcs(s))
    381       Expand(s);
    382     CacheImpl<A>::InitArcIterator(s, data);
    383   }
    384 
    385   void Expand(StateId s) {
    386     rmeps_state_.Expand(s);
    387     SetFinal(s, rmeps_state_.Final());
    388     vector<A> &arcs = rmeps_state_.Arcs();
    389     while (!arcs.empty()) {
    390       AddArc(s, arcs.back());
    391       arcs.pop_back();
    392     }
    393     SetArcs(s);
    394   }
    395 
    396  private:
    397   const Fst<A> *fst_;
    398   vector<Weight> distance_;
    399   FifoQueue<StateId> queue_;
    400   RmEpsilonState<A, FifoQueue<StateId> > rmeps_state_;
    401 
    402   DISALLOW_EVIL_CONSTRUCTORS(RmEpsilonFstImpl);
    403 };
    404 
    405 
    406 // Removes epsilon-transitions (when both the input and output label
    407 // are an epsilon) from a transducer. The result will be an equivalent
    408 // FST that has no such epsilon transitions.  This version is a
    409 // delayed Fst.
    410 //
    411 // Complexity:
    412 // - Time:
    413 //   - Unweighted: O(v^2 + v e)
    414 //   - General: exponential
    415 // - Space: O(v e)
    416 // where v = # of states visited, e = # of arcs visited. Constant time
    417 // to visit an input state or arc is assumed and exclusive of caching.
    418 //
    419 // References:
    420 // - Mehryar Mohri. Generic Epsilon-Removal and Input
    421 //   Epsilon-Normalization Algorithms for Weighted Transducers,
    422 //   "International Journal of Computer Science", 13(1):129-143 (2002).
    423 template <class A>
    424 class RmEpsilonFst : public Fst<A> {
    425  public:
    426   friend class ArcIterator< RmEpsilonFst<A> >;
    427   friend class CacheStateIterator< RmEpsilonFst<A> >;
    428   friend class CacheArcIterator< RmEpsilonFst<A> >;
    429 
    430   typedef A Arc;
    431   typedef typename A::Weight Weight;
    432   typedef typename A::StateId StateId;
    433   typedef CacheState<A> State;
    434 
    435   RmEpsilonFst(const Fst<A> &fst)
    436       : impl_(new RmEpsilonFstImpl<A>(fst, RmEpsilonFstOptions())) {}
    437 
    438   RmEpsilonFst(const Fst<A> &fst, const RmEpsilonFstOptions &opts)
    439       : impl_(new RmEpsilonFstImpl<A>(fst, opts)) {}
    440 
    441   explicit RmEpsilonFst(const RmEpsilonFst<A> &fst) : impl_(fst.impl_) {
    442     impl_->IncrRefCount();
    443   }
    444 
    445   virtual ~RmEpsilonFst() { if (!impl_->DecrRefCount()) delete impl_;  }
    446 
    447   virtual StateId Start() const { return impl_->Start(); }
    448 
    449   virtual Weight Final(StateId s) const { return impl_->Final(s); }
    450 
    451   virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
    452 
    453   virtual size_t NumInputEpsilons(StateId s) const {
    454     return impl_->NumInputEpsilons(s);
    455   }
    456 
    457   virtual size_t NumOutputEpsilons(StateId s) const {
    458     return impl_->NumOutputEpsilons(s);
    459   }
    460 
    461   virtual uint64 Properties(uint64 mask, bool test) const {
    462     if (test) {
    463       uint64 known, test = TestProperties(*this, mask, &known);
    464       impl_->SetProperties(test, known);
    465       return test & mask;
    466     } else {
    467       return impl_->Properties(mask);
    468     }
    469   }
    470 
    471   virtual const string& Type() const { return impl_->Type(); }
    472 
    473   virtual RmEpsilonFst<A> *Copy() const {
    474     return new RmEpsilonFst<A>(*this);
    475   }
    476 
    477   virtual const SymbolTable* InputSymbols() const {
    478     return impl_->InputSymbols();
    479   }
    480 
    481   virtual const SymbolTable* OutputSymbols() const {
    482     return impl_->OutputSymbols();
    483   }
    484 
    485   virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
    486 
    487   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
    488     impl_->InitArcIterator(s, data);
    489   }
    490 
    491  protected:
    492   RmEpsilonFstImpl<A> *Impl() { return impl_; }
    493 
    494  private:
    495   RmEpsilonFstImpl<A> *impl_;
    496 
    497   void operator=(const RmEpsilonFst<A> &fst);  // disallow
    498 };
    499 
    500 
    501 // Specialization for RmEpsilonFst.
    502 template<class A>
    503 class StateIterator< RmEpsilonFst<A> >
    504     : public CacheStateIterator< RmEpsilonFst<A> > {
    505  public:
    506   explicit StateIterator(const RmEpsilonFst<A> &fst)
    507       : CacheStateIterator< RmEpsilonFst<A> >(fst) {}
    508 };
    509 
    510 
    511 // Specialization for RmEpsilonFst.
    512 template <class A>
    513 class ArcIterator< RmEpsilonFst<A> >
    514     : public CacheArcIterator< RmEpsilonFst<A> > {
    515  public:
    516   typedef typename A::StateId StateId;
    517 
    518   ArcIterator(const RmEpsilonFst<A> &fst, StateId s)
    519       : CacheArcIterator< RmEpsilonFst<A> >(fst, s) {
    520     if (!fst.impl_->HasArcs(s))
    521       fst.impl_->Expand(s);
    522   }
    523 
    524  private:
    525   DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
    526 };
    527 
    528 
    529 template <class A> inline
    530 void RmEpsilonFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
    531   data->base = new StateIterator< RmEpsilonFst<A> >(*this);
    532 }
    533 
    534 
    535 // Useful alias when using StdArc.
    536 typedef RmEpsilonFst<StdArc> StdRmEpsilonFst;
    537 
    538 }  // namespace fst
    539 
    540 #endif  // FST_LIB_RMEPSILON_H__
    541