Home | History | Annotate | Download | only in lib
      1 // rmepsilon.h
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //      http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Author: allauzen (at) cs.nyu.edu (Cyril Allauzen)
     16 //
     17 // \file
// Functions and classes that implement epsilon-removal.
     19 
     20 #ifndef FST_LIB_RMEPSILON_H__
     21 #define FST_LIB_RMEPSILON_H__
     22 
     23 #include <ext/hash_map>
     24 using __gnu_cxx::hash_map;
     25 #include <ext/slist>
     26 using __gnu_cxx::slist;
     27 
     28 #include "fst/lib/arcfilter.h"
     29 #include "fst/lib/cache.h"
     30 #include "fst/lib/connect.h"
     31 #include "fst/lib/factor-weight.h"
     32 #include "fst/lib/invert.h"
     33 #include "fst/lib/map.h"
     34 #include "fst/lib/queue.h"
     35 #include "fst/lib/shortest-distance.h"
     36 #include "fst/lib/topsort.h"
     37 
     38 namespace fst {
     39 
     40 template <class Arc, class Queue>
     41 struct RmEpsilonOptions
     42     : public ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc> > {
     43   typedef typename Arc::StateId StateId;
     44 
     45   bool connect;  // Connect output
     46 
     47   RmEpsilonOptions(Queue *q, float d = kDelta, bool c = true)
     48       : ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc> >(
     49           q, EpsilonArcFilter<Arc>(), kNoStateId, d), connect(c) {}
     50 
     51 };
     52 
     53 
     54 // Computation state of the epsilon-removal algorithm.
     55 template <class Arc, class Queue>
     56 class RmEpsilonState {
     57  public:
     58   typedef typename Arc::Label Label;
     59   typedef typename Arc::StateId StateId;
     60   typedef typename Arc::Weight Weight;
     61 
     62   RmEpsilonState(const Fst<Arc> &fst,
     63                  vector<Weight> *distance,
     64                  const RmEpsilonOptions<Arc, Queue> &opts)
     65       : fst_(fst), distance_(distance), sd_state_(fst_, distance, opts, true) {
     66   }
     67 
     68   // Compute arcs and final weight for state 's'
     69   void Expand(StateId s);
     70 
     71   // Returns arcs of expanded state.
     72   vector<Arc> &Arcs() { return arcs_; }
     73 
     74   // Returns final weight of expanded state.
     75   const Weight &Final() const { return final_; }
     76 
     77  private:
     78   struct Element {
     79     Label ilabel;
     80     Label olabel;
     81     StateId nextstate;
     82 
     83     Element() {}
     84 
     85     Element(Label i, Label o, StateId s)
     86         : ilabel(i), olabel(o), nextstate(s) {}
     87   };
     88 
     89   class ElementKey {
     90    public:
     91     size_t operator()(const Element& e) const {
     92       return static_cast<size_t>(e.nextstate);
     93       return static_cast<size_t>(e.nextstate +
     94                                  e.ilabel * kPrime0 +
     95                                  e.olabel * kPrime1);
     96     }
     97 
     98    private:
     99     static const int kPrime0 = 7853;
    100     static const int kPrime1 = 7867;
    101   };
    102 
    103   class ElementEqual {
    104    public:
    105     bool operator()(const Element &e1, const Element &e2) const {
    106       return (e1.ilabel == e2.ilabel) &&  (e1.olabel == e2.olabel)
    107                          && (e1.nextstate == e2.nextstate);
    108     }
    109   };
    110 
    111  private:
    112   typedef hash_map<Element, pair<StateId, ssize_t>,
    113                    ElementKey, ElementEqual> ElementMap;
    114 
    115   const Fst<Arc> &fst_;
    116   // Distance from state being expanded in epsilon-closure.
    117   vector<Weight> *distance_;
    118   // Shortest distance algorithm computation state.
    119   ShortestDistanceState<Arc, Queue, EpsilonArcFilter<Arc> > sd_state_;
    120   // Maps an element 'e' to a pair 'p' corresponding to a position
    121   // in the arcs vector of the state being expanded. 'e' corresponds
    122   // to the position 'p.second' in the 'arcs_' vector if 'p.first' is
    123   // equal to the state being expanded.
    124   ElementMap element_map_;
    125   EpsilonArcFilter<Arc> eps_filter_;
    126   stack<StateId> eps_queue_;      // Queue used to visit the epsilon-closure
    127   vector<bool> visited_;          // '[i] = true' if state 'i' has been visited
    128   slist<StateId> visited_states_; // List of visited states
    129   vector<Arc> arcs_;              // Arcs of state being expanded
    130   Weight final_;                  // Final weight of state being expanded
    131 
    132   void operator=(const RmEpsilonState);  // Disallow
    133 };
    134 
    135 
    136 template <class Arc, class Queue>
    137 void RmEpsilonState<Arc,Queue>::Expand(typename Arc::StateId source) {
    138    sd_state_.ShortestDistance(source);
    139    eps_queue_.push(source);
    140    final_ = Weight::Zero();
    141    arcs_.clear();
    142 
    143    while (!eps_queue_.empty()) {
    144      StateId state = eps_queue_.top();
    145      eps_queue_.pop();
    146 
    147      while ((StateId)visited_.size() <= state) visited_.push_back(false);
    148      visited_[state] = true;
    149      visited_states_.push_front(state);
    150 
    151      for (ArcIterator< Fst<Arc> > ait(fst_, state);
    152           !ait.Done();
    153           ait.Next()) {
    154        Arc arc = ait.Value();
    155        arc.weight = Times((*distance_)[state], arc.weight);
    156 
    157        if (eps_filter_(arc)) {
    158          while ((StateId)visited_.size() <= arc.nextstate)
    159            visited_.push_back(false);
    160          if (!visited_[arc.nextstate])
    161            eps_queue_.push(arc.nextstate);
    162        } else {
    163           Element element(arc.ilabel, arc.olabel, arc.nextstate);
    164           typename ElementMap::iterator it = element_map_.find(element);
    165           if (it == element_map_.end()) {
    166             element_map_.insert(
    167                 pair<Element, pair<StateId, ssize_t> >
    168                 (element, pair<StateId, ssize_t>(source, arcs_.size())));
    169             arcs_.push_back(arc);
    170           } else {
    171             if (((*it).second).first == source) {
    172               Weight &w = arcs_[((*it).second).second].weight;
    173               w = Plus(w, arc.weight);
    174             } else {
    175               ((*it).second).first = source;
    176               ((*it).second).second = arcs_.size();
    177               arcs_.push_back(arc);
    178             }
    179           }
    180         }
    181      }
    182      final_ = Plus(final_, Times((*distance_)[state], fst_.Final(state)));
    183    }
    184 
    185    while (!visited_states_.empty()) {
    186      visited_[visited_states_.front()] = false;
    187      visited_states_.pop_front();
    188    }
    189 }
    190 
    191 
    192 // Removes epsilon-transitions (when both the input and output label
    193 // are an epsilon) from a transducer. The result will be an equivalent
    194 // FST that has no such epsilon transitions.  This version modifies
    195 // its input. It allows fine control via the options argument; see
    196 // below for a simpler interface.
    197 //
    198 // The vector 'distance' will be used to hold the shortest distances
    199 // during the epsilon-closure computation. The state queue discipline
    200 // and convergence delta are taken in the options argument.
    201 template <class Arc, class Queue>
    202 void RmEpsilon(MutableFst<Arc> *fst,
    203                vector<typename Arc::Weight> *distance,
    204                const RmEpsilonOptions<Arc, Queue> &opts) {
    205   typedef typename Arc::StateId StateId;
    206   typedef typename Arc::Weight Weight;
    207   typedef typename Arc::Label Label;
    208 
    209   // States sorted in topological order when (acyclic) or generic
    210   // topological order (cyclic).
    211   vector<StateId> states;
    212 
    213   if (fst->Properties(kTopSorted, false) & kTopSorted) {
    214     for (StateId i = 0; i < (StateId)fst->NumStates(); i++)
    215       states.push_back(i);
    216   } else if (fst->Properties(kAcyclic, false) & kAcyclic) {
    217     vector<StateId> order;
    218     bool acyclic;
    219     TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic);
    220     DfsVisit(*fst, &top_order_visitor, EpsilonArcFilter<Arc>());
    221     if (!acyclic)
    222       LOG(FATAL) << "RmEpsilon: not acyclic though property bit is set";
    223     states.resize(order.size());
    224     for (StateId i = 0; i < (StateId)order.size(); i++)
    225       states[order[i]] = i;
    226   } else {
    227      uint64 props;
    228      vector<StateId> scc;
    229      SccVisitor<Arc> scc_visitor(&scc, 0, 0, &props);
    230      DfsVisit(*fst, &scc_visitor, EpsilonArcFilter<Arc>());
    231      vector<StateId> first(scc.size(), kNoStateId);
    232      vector<StateId> next(scc.size(), kNoStateId);
    233      for (StateId i = 0; i < (StateId)scc.size(); i++) {
    234        if (first[scc[i]] != kNoStateId)
    235          next[i] = first[scc[i]];
    236        first[scc[i]] = i;
    237      }
    238      for (StateId i = 0; i < (StateId)first.size(); i++)
    239        for (StateId j = first[i]; j != kNoStateId; j = next[j])
    240          states.push_back(j);
    241   }
    242 
    243   RmEpsilonState<Arc, Queue>
    244     rmeps_state(*fst, distance, opts);
    245 
    246   while (!states.empty()) {
    247     StateId state = states.back();
    248     states.pop_back();
    249     rmeps_state.Expand(state);
    250     fst->SetFinal(state, rmeps_state.Final());
    251     fst->DeleteArcs(state);
    252     vector<Arc> &arcs = rmeps_state.Arcs();
    253     while (!arcs.empty()) {
    254       fst->AddArc(state, arcs.back());
    255       arcs.pop_back();
    256     }
    257   }
    258 
    259   fst->SetProperties(RmEpsilonProperties(
    260                          fst->Properties(kFstProperties, false)),
    261                      kFstProperties);
    262 
    263   if (opts.connect)
    264     Connect(fst);
    265 }
    266 
    267 
    268 // Removes epsilon-transitions (when both the input and output label
    269 // are an epsilon) from a transducer. The result will be an equivalent
    270 // FST that has no such epsilon transitions. This version modifies its
    271 // input. It has a simplified interface; see above for a version that
    272 // allows finer control.
    273 //
    274 // Complexity:
    275 // - Time:
    276 //   - Unweighted: O(V2 + V E)
    277 //   - Acyclic: O(V2 + V E)
    278 //   - Tropical semiring: O(V2 log V + V E)
    279 //   - General: exponential
    280 // - Space: O(V E)
    281 // where V = # of states visited, E = # of arcs.
    282 //
    283 // References:
    284 // - Mehryar Mohri. Generic Epsilon-Removal and Input
    285 //   Epsilon-Normalization Algorithms for Weighted Transducers,
    286 //   "International Journal of Computer Science", 13(1):129-143 (2002).
    287 template <class Arc>
    288 void RmEpsilon(MutableFst<Arc> *fst, bool connect = true) {
    289   typedef typename Arc::StateId StateId;
    290   typedef typename Arc::Weight Weight;
    291   typedef typename Arc::Label Label;
    292 
    293   vector<Weight> distance;
    294   AutoQueue<StateId> state_queue(*fst, &distance, EpsilonArcFilter<Arc>());
    295   RmEpsilonOptions<Arc, AutoQueue<StateId> >
    296     opts(&state_queue, kDelta, connect);
    297 
    298   RmEpsilon(fst, &distance, opts);
    299 }
    300 
    301 
    302 struct RmEpsilonFstOptions : CacheOptions {
    303   float delta;
    304 
    305   RmEpsilonFstOptions(const CacheOptions &opts, float delta = kDelta)
    306       : CacheOptions(opts), delta(delta) {}
    307 
    308   explicit RmEpsilonFstOptions(float delta = kDelta) : delta(delta) {}
    309 };
    310 
    311 
    312 // Implementation of delayed RmEpsilonFst.
    313 template <class A>
    314 class RmEpsilonFstImpl : public CacheImpl<A> {
    315  public:
    316   using FstImpl<A>::SetType;
    317   using FstImpl<A>::SetProperties;
    318   using FstImpl<A>::Properties;
    319   using FstImpl<A>::SetInputSymbols;
    320   using FstImpl<A>::SetOutputSymbols;
    321 
    322   using CacheBaseImpl< CacheState<A> >::HasStart;
    323   using CacheBaseImpl< CacheState<A> >::HasFinal;
    324   using CacheBaseImpl< CacheState<A> >::HasArcs;
    325 
    326   typedef typename A::Label Label;
    327   typedef typename A::Weight Weight;
    328   typedef typename A::StateId StateId;
    329   typedef CacheState<A> State;
    330 
    331   RmEpsilonFstImpl(const Fst<A>& fst, const RmEpsilonFstOptions &opts)
    332       : CacheImpl<A>(opts),
    333         fst_(fst.Copy()),
    334         rmeps_state_(
    335             *fst_,
    336             &distance_,
    337             RmEpsilonOptions<A, FifoQueue<StateId> >(&queue_, opts.delta,
    338                                                      false)
    339             ) {
    340     SetType("rmepsilon");
    341     uint64 props = fst.Properties(kFstProperties, false);
    342     SetProperties(RmEpsilonProperties(props, true), kCopyProperties);
    343   }
    344 
    345   ~RmEpsilonFstImpl() {
    346     delete fst_;
    347   }
    348 
    349   StateId Start() {
    350     if (!HasStart()) {
    351       SetStart(fst_->Start());
    352     }
    353     return CacheImpl<A>::Start();
    354   }
    355 
    356   Weight Final(StateId s) {
    357     if (!HasFinal(s)) {
    358       Expand(s);
    359     }
    360     return CacheImpl<A>::Final(s);
    361   }
    362 
    363   size_t NumArcs(StateId s) {
    364     if (!HasArcs(s))
    365       Expand(s);
    366     return CacheImpl<A>::NumArcs(s);
    367   }
    368 
    369   size_t NumInputEpsilons(StateId s) {
    370     if (!HasArcs(s))
    371       Expand(s);
    372     return CacheImpl<A>::NumInputEpsilons(s);
    373   }
    374 
    375   size_t NumOutputEpsilons(StateId s) {
    376     if (!HasArcs(s))
    377       Expand(s);
    378     return CacheImpl<A>::NumOutputEpsilons(s);
    379   }
    380 
    381   void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
    382     if (!HasArcs(s))
    383       Expand(s);
    384     CacheImpl<A>::InitArcIterator(s, data);
    385   }
    386 
    387   void Expand(StateId s) {
    388     rmeps_state_.Expand(s);
    389     SetFinal(s, rmeps_state_.Final());
    390     vector<A> &arcs = rmeps_state_.Arcs();
    391     while (!arcs.empty()) {
    392       AddArc(s, arcs.back());
    393       arcs.pop_back();
    394     }
    395     SetArcs(s);
    396   }
    397 
    398  private:
    399   const Fst<A> *fst_;
    400   vector<Weight> distance_;
    401   FifoQueue<StateId> queue_;
    402   RmEpsilonState<A, FifoQueue<StateId> > rmeps_state_;
    403 
    404   DISALLOW_EVIL_CONSTRUCTORS(RmEpsilonFstImpl);
    405 };
    406 
    407 
    408 // Removes epsilon-transitions (when both the input and output label
    409 // are an epsilon) from a transducer. The result will be an equivalent
    410 // FST that has no such epsilon transitions.  This version is a
    411 // delayed Fst.
    412 //
    413 // Complexity:
    414 // - Time:
    415 //   - Unweighted: O(v^2 + v e)
    416 //   - General: exponential
    417 // - Space: O(v e)
    418 // where v = # of states visited, e = # of arcs visited. Constant time
    419 // to visit an input state or arc is assumed and exclusive of caching.
    420 //
    421 // References:
    422 // - Mehryar Mohri. Generic Epsilon-Removal and Input
    423 //   Epsilon-Normalization Algorithms for Weighted Transducers,
    424 //   "International Journal of Computer Science", 13(1):129-143 (2002).
    425 template <class A>
    426 class RmEpsilonFst : public Fst<A> {
    427  public:
    428   friend class ArcIterator< RmEpsilonFst<A> >;
    429   friend class CacheStateIterator< RmEpsilonFst<A> >;
    430   friend class CacheArcIterator< RmEpsilonFst<A> >;
    431 
    432   typedef A Arc;
    433   typedef typename A::Weight Weight;
    434   typedef typename A::StateId StateId;
    435   typedef CacheState<A> State;
    436 
    437   RmEpsilonFst(const Fst<A> &fst)
    438       : impl_(new RmEpsilonFstImpl<A>(fst, RmEpsilonFstOptions())) {}
    439 
    440   RmEpsilonFst(const Fst<A> &fst, const RmEpsilonFstOptions &opts)
    441       : impl_(new RmEpsilonFstImpl<A>(fst, opts)) {}
    442 
    443   explicit RmEpsilonFst(const RmEpsilonFst<A> &fst) : impl_(fst.impl_) {
    444     impl_->IncrRefCount();
    445   }
    446 
    447   virtual ~RmEpsilonFst() { if (!impl_->DecrRefCount()) delete impl_;  }
    448 
    449   virtual StateId Start() const { return impl_->Start(); }
    450 
    451   virtual Weight Final(StateId s) const { return impl_->Final(s); }
    452 
    453   virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
    454 
    455   virtual size_t NumInputEpsilons(StateId s) const {
    456     return impl_->NumInputEpsilons(s);
    457   }
    458 
    459   virtual size_t NumOutputEpsilons(StateId s) const {
    460     return impl_->NumOutputEpsilons(s);
    461   }
    462 
    463   virtual uint64 Properties(uint64 mask, bool test) const {
    464     if (test) {
    465       uint64 known, test = TestProperties(*this, mask, &known);
    466       impl_->SetProperties(test, known);
    467       return test & mask;
    468     } else {
    469       return impl_->Properties(mask);
    470     }
    471   }
    472 
    473   virtual const string& Type() const { return impl_->Type(); }
    474 
    475   virtual RmEpsilonFst<A> *Copy() const {
    476     return new RmEpsilonFst<A>(*this);
    477   }
    478 
    479   virtual const SymbolTable* InputSymbols() const {
    480     return impl_->InputSymbols();
    481   }
    482 
    483   virtual const SymbolTable* OutputSymbols() const {
    484     return impl_->OutputSymbols();
    485   }
    486 
    487   virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
    488 
    489   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
    490     impl_->InitArcIterator(s, data);
    491   }
    492 
    493  protected:
    494   RmEpsilonFstImpl<A> *Impl() { return impl_; }
    495 
    496  private:
    497   RmEpsilonFstImpl<A> *impl_;
    498 
    499   void operator=(const RmEpsilonFst<A> &fst);  // disallow
    500 };
    501 
    502 
    503 // Specialization for RmEpsilonFst.
    504 template<class A>
    505 class StateIterator< RmEpsilonFst<A> >
    506     : public CacheStateIterator< RmEpsilonFst<A> > {
    507  public:
    508   explicit StateIterator(const RmEpsilonFst<A> &fst)
    509       : CacheStateIterator< RmEpsilonFst<A> >(fst) {}
    510 };
    511 
    512 
    513 // Specialization for RmEpsilonFst.
    514 template <class A>
    515 class ArcIterator< RmEpsilonFst<A> >
    516     : public CacheArcIterator< RmEpsilonFst<A> > {
    517  public:
    518   typedef typename A::StateId StateId;
    519 
    520   ArcIterator(const RmEpsilonFst<A> &fst, StateId s)
    521       : CacheArcIterator< RmEpsilonFst<A> >(fst, s) {
    522     if (!fst.impl_->HasArcs(s))
    523       fst.impl_->Expand(s);
    524   }
    525 
    526  private:
    527   DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
    528 };
    529 
    530 
    531 template <class A> inline
    532 void RmEpsilonFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
    533   data->base = new StateIterator< RmEpsilonFst<A> >(*this);
    534 }
    535 
    536 
// Useful alias when using StdArc: the delayed epsilon-removal FST
// instantiated for StdArc arcs.
typedef RmEpsilonFst<StdArc> StdRmEpsilonFst;
    539 
    540 }  // namespace fst
    541 
    542 #endif  // FST_LIB_RMEPSILON_H__
    543