Home | History | Annotate | Download | only in fst
      1 // minimize.h
      2 // minimize.h
      3 
      4 // Licensed under the Apache License, Version 2.0 (the "License");
      5 // you may not use this file except in compliance with the License.
      6 // You may obtain a copy of the License at
      7 //
      8 //     http://www.apache.org/licenses/LICENSE-2.0
      9 //
     10 // Unless required by applicable law or agreed to in writing, software
     11 // distributed under the License is distributed on an "AS IS" BASIS,
     12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 // See the License for the specific language governing permissions and
     14 // limitations under the License.
     15 //
     16 // Copyright 2005-2010 Google, Inc.
     17 // Author: johans (at) google.com (Johan Schalkwyk)
     18 //
     19 // \file Functions and classes to minimize a finite state acceptor
     20 //
     21 
     22 #ifndef FST_LIB_MINIMIZE_H__
     23 #define FST_LIB_MINIMIZE_H__
     24 
     25 #include <cmath>
     26 
     27 #include <algorithm>
     28 #include <map>
     29 #include <queue>
     30 #include <vector>
     31 using std::vector;
     32 
     33 #include <fst/arcsort.h>
     34 #include <fst/connect.h>
     35 #include <fst/dfs-visit.h>
     36 #include <fst/encode.h>
     37 #include <fst/factor-weight.h>
     38 #include <fst/fst.h>
     39 #include <fst/mutable-fst.h>
     40 #include <fst/partition.h>
     41 #include <fst/push.h>
     42 #include <fst/queue.h>
     43 #include <fst/reverse.h>
     44 #include <fst/state-map.h>
     45 
     46 
     47 namespace fst {
     48 
     49 // comparator for creating partition based on sorting on
     50 // - states
     51 // - final weight
     52 // - out degree,
     53 // -  (input label, output label, weight, destination_block)
     54 template <class A>
     55 class StateComparator {
     56  public:
     57   typedef typename A::StateId StateId;
     58   typedef typename A::Weight Weight;
     59 
     60   static const uint32 kCompareFinal     = 0x00000001;
     61   static const uint32 kCompareOutDegree = 0x00000002;
     62   static const uint32 kCompareArcs      = 0x00000004;
     63   static const uint32 kCompareAll       = 0x00000007;
     64 
     65   StateComparator(const Fst<A>& fst,
     66                   const Partition<typename A::StateId>& partition,
     67                   uint32 flags = kCompareAll)
     68       : fst_(fst), partition_(partition), flags_(flags) {}
     69 
     70   // compare state x with state y based on sort criteria
     71   bool operator()(const StateId x, const StateId y) const {
     72     // check for final state equivalence
     73     if (flags_ & kCompareFinal) {
     74       const size_t xfinal = fst_.Final(x).Hash();
     75       const size_t yfinal = fst_.Final(y).Hash();
     76       if      (xfinal < yfinal) return true;
     77       else if (xfinal > yfinal) return false;
     78     }
     79 
     80     if (flags_ & kCompareOutDegree) {
     81       // check for # arcs
     82       if (fst_.NumArcs(x) < fst_.NumArcs(y)) return true;
     83       if (fst_.NumArcs(x) > fst_.NumArcs(y)) return false;
     84 
     85       if (flags_ & kCompareArcs) {
     86         // # arcs are equal, check for arc match
     87         for (ArcIterator<Fst<A> > aiter1(fst_, x), aiter2(fst_, y);
     88              !aiter1.Done() && !aiter2.Done(); aiter1.Next(), aiter2.Next()) {
     89           const A& arc1 = aiter1.Value();
     90           const A& arc2 = aiter2.Value();
     91           if (arc1.ilabel < arc2.ilabel) return true;
     92           if (arc1.ilabel > arc2.ilabel) return false;
     93 
     94           if (partition_.class_id(arc1.nextstate) <
     95               partition_.class_id(arc2.nextstate)) return true;
     96           if (partition_.class_id(arc1.nextstate) >
     97               partition_.class_id(arc2.nextstate)) return false;
     98         }
     99       }
    100     }
    101 
    102     return false;
    103   }
    104 
    105  private:
    106   const Fst<A>& fst_;
    107   const Partition<typename A::StateId>& partition_;
    108   const uint32 flags_;
    109 };
    110 
    111 template <class A> const uint32 StateComparator<A>::kCompareFinal;
    112 template <class A> const uint32 StateComparator<A>::kCompareOutDegree;
    113 template <class A> const uint32 StateComparator<A>::kCompareArcs;
    114 template <class A> const uint32 StateComparator<A>::kCompareAll;
    115 
    116 
    117 // Computes equivalence classes for cyclic Fsts. For cyclic minimization
    118 // we use the classic HopCroft minimization algorithm, which is of
    119 //
    120 //   O(E)log(N),
    121 //
    122 // where E is the number of edges in the machine and N is number of states.
    123 //
    124 // The following paper describes the original algorithm
    125 //  An N Log N algorithm for minimizing states in a finite automaton
    126 //  by John HopCroft, January 1971
    127 //
    128 template <class A, class Queue>
    129 class CyclicMinimizer {
    130  public:
    131   typedef typename A::Label Label;
    132   typedef typename A::StateId StateId;
    133   typedef typename A::StateId ClassId;
    134   typedef typename A::Weight Weight;
    135   typedef ReverseArc<A> RevA;
    136 
    137   CyclicMinimizer(const ExpandedFst<A>& fst) {
    138     Initialize(fst);
    139     Compute(fst);
    140   }
    141 
    142   ~CyclicMinimizer() {
    143     delete aiter_queue_;
    144   }
    145 
    146   const Partition<StateId>& partition() const {
    147     return P_;
    148   }
    149 
    150   // helper classes
    151  private:
    152   typedef ArcIterator<Fst<RevA> > ArcIter;
    153   class ArcIterCompare {
    154    public:
    155     ArcIterCompare(const Partition<StateId>& partition)
    156         : partition_(partition) {}
    157 
    158     ArcIterCompare(const ArcIterCompare& comp)
    159         : partition_(comp.partition_) {}
    160 
    161     // compare two iterators based on there input labels, and proto state
    162     // (partition class Ids)
    163     bool operator()(const ArcIter* x, const ArcIter* y) const {
    164       const RevA& xarc = x->Value();
    165       const RevA& yarc = y->Value();
    166       return (xarc.ilabel > yarc.ilabel);
    167     }
    168 
    169    private:
    170     const Partition<StateId>& partition_;
    171   };
    172 
    173   typedef priority_queue<ArcIter*, vector<ArcIter*>, ArcIterCompare>
    174   ArcIterQueue;
    175 
    176   // helper methods
    177  private:
    178   // prepartitions the space into equivalence classes with
    179   //   same final weight
    180   //   same # arcs per state
    181   //   same outgoing arcs
    182   void PrePartition(const Fst<A>& fst) {
    183     VLOG(5) << "PrePartition";
    184 
    185     typedef map<StateId, StateId, StateComparator<A> > EquivalenceMap;
    186     StateComparator<A> comp(fst, P_, StateComparator<A>::kCompareFinal);
    187     EquivalenceMap equiv_map(comp);
    188 
    189     StateIterator<Fst<A> > siter(fst);
    190     StateId class_id = P_.AddClass();
    191     P_.Add(siter.Value(), class_id);
    192     equiv_map[siter.Value()] = class_id;
    193     L_.Enqueue(class_id);
    194     for (siter.Next(); !siter.Done(); siter.Next()) {
    195       StateId  s = siter.Value();
    196       typename EquivalenceMap::const_iterator it = equiv_map.find(s);
    197       if (it == equiv_map.end()) {
    198         class_id = P_.AddClass();
    199         P_.Add(s, class_id);
    200         equiv_map[s] = class_id;
    201         L_.Enqueue(class_id);
    202       } else {
    203         P_.Add(s, it->second);
    204         equiv_map[s] = it->second;
    205       }
    206     }
    207 
    208     VLOG(5) << "Initial Partition: " << P_.num_classes();
    209   }
    210 
    211   // - Create inverse transition Tr_ = rev(fst)
    212   // - loop over states in fst and split on final, creating two blocks
    213   //   in the partition corresponding to final, non-final
    214   void Initialize(const Fst<A>& fst) {
    215     // construct Tr
    216     Reverse(fst, &Tr_);
    217     ILabelCompare<RevA> ilabel_comp;
    218     ArcSort(&Tr_, ilabel_comp);
    219 
    220     // initial split (F, S - F)
    221     P_.Initialize(Tr_.NumStates() - 1);
    222 
    223     // prep partition
    224     PrePartition(fst);
    225 
    226     // allocate arc iterator queue
    227     ArcIterCompare comp(P_);
    228     aiter_queue_ = new ArcIterQueue(comp);
    229   }
    230 
    231   // partition all classes with destination C
    232   void Split(ClassId C) {
    233     // Prep priority queue. Open arc iterator for each state in C, and
    234     // insert into priority queue.
    235     for (PartitionIterator<StateId> siter(P_, C);
    236          !siter.Done(); siter.Next()) {
    237       StateId s = siter.Value();
    238       if (Tr_.NumArcs(s + 1))
    239         aiter_queue_->push(new ArcIterator<Fst<RevA> >(Tr_, s + 1));
    240     }
    241 
    242     // Now pop arc iterator from queue, split entering equivalence class
    243     // re-insert updated iterator into queue.
    244     Label prev_label = -1;
    245     while (!aiter_queue_->empty()) {
    246       ArcIterator<Fst<RevA> >* aiter = aiter_queue_->top();
    247       aiter_queue_->pop();
    248       if (aiter->Done()) {
    249         delete aiter;
    250         continue;
    251      }
    252 
    253       const RevA& arc = aiter->Value();
    254       StateId from_state = aiter->Value().nextstate - 1;
    255       Label   from_label = arc.ilabel;
    256       if (prev_label != from_label)
    257         P_.FinalizeSplit(&L_);
    258 
    259       StateId from_class = P_.class_id(from_state);
    260       if (P_.class_size(from_class) > 1)
    261         P_.SplitOn(from_state);
    262 
    263       prev_label = from_label;
    264       aiter->Next();
    265       if (aiter->Done())
    266         delete aiter;
    267       else
    268         aiter_queue_->push(aiter);
    269     }
    270     P_.FinalizeSplit(&L_);
    271   }
    272 
    273   // Main loop for hopcroft minimization.
    274   void Compute(const Fst<A>& fst) {
    275     // process active classes (FIFO, or FILO)
    276     while (!L_.Empty()) {
    277       ClassId C = L_.Head();
    278       L_.Dequeue();
    279 
    280       // split on C, all labels in C
    281       Split(C);
    282     }
    283   }
    284 
    285   // helper data
    286  private:
    287   // Partioning of states into equivalence classes
    288   Partition<StateId> P_;
    289 
    290   // L = set of active classes to be processed in partition P
    291   Queue L_;
    292 
    293   // reverse transition function
    294   VectorFst<RevA> Tr_;
    295 
    296   // Priority queue of open arc iterators for all states in the 'splitter'
    297   // equivalence class
    298   ArcIterQueue* aiter_queue_;
    299 };
    300 
    301 
    302 // Computes equivalence classes for acyclic Fsts. The implementation details
    303 // for this algorithms is documented by the following paper.
    304 //
    305 // Minimization of acyclic deterministic automata in linear time
    306 //  Dominque Revuz
    307 //
    308 // Complexity O(|E|)
    309 //
    310 template <class A>
    311 class AcyclicMinimizer {
    312  public:
    313   typedef typename A::Label Label;
    314   typedef typename A::StateId StateId;
    315   typedef typename A::StateId ClassId;
    316   typedef typename A::Weight Weight;
    317 
    318   AcyclicMinimizer(const ExpandedFst<A>& fst) {
    319     Initialize(fst);
    320     Refine(fst);
    321   }
    322 
    323   const Partition<StateId>& partition() {
    324     return partition_;
    325   }
    326 
    327   // helper classes
    328  private:
    329   // DFS visitor to compute the height (distance) to final state.
    330   class HeightVisitor {
    331    public:
    332     HeightVisitor() : max_height_(0), num_states_(0) { }
    333 
    334     // invoked before dfs visit
    335     void InitVisit(const Fst<A>& fst) {}
    336 
    337     // invoked when state is discovered (2nd arg is DFS tree root)
    338     bool InitState(StateId s, StateId root) {
    339       // extend height array and initialize height (distance) to 0
    340       for (size_t i = height_.size(); i <= s; ++i)
    341         height_.push_back(-1);
    342 
    343       if (s >= num_states_) num_states_ = s + 1;
    344       return true;
    345     }
    346 
    347     // invoked when tree arc examined (to undiscoverted state)
    348     bool TreeArc(StateId s, const A& arc) {
    349       return true;
    350     }
    351 
    352     // invoked when back arc examined (to unfinished state)
    353     bool BackArc(StateId s, const A& arc) {
    354       return true;
    355     }
    356 
    357     // invoked when forward or cross arc examined (to finished state)
    358     bool ForwardOrCrossArc(StateId s, const A& arc) {
    359       if (height_[arc.nextstate] + 1 > height_[s])
    360         height_[s] = height_[arc.nextstate] + 1;
    361       return true;
    362     }
    363 
    364     // invoked when state finished (parent is kNoStateId for tree root)
    365     void FinishState(StateId s, StateId parent, const A* parent_arc) {
    366       if (height_[s] == -1) height_[s] = 0;
    367       StateId h = height_[s] +  1;
    368       if (parent >= 0) {
    369         if (h > height_[parent]) height_[parent] = h;
    370         if (h > max_height_)     max_height_ = h;
    371       }
    372     }
    373 
    374     // invoked after DFS visit
    375     void FinishVisit() {}
    376 
    377     size_t max_height() const { return max_height_; }
    378 
    379     const vector<StateId>& height() const { return height_; }
    380 
    381     const size_t num_states() const { return num_states_; }
    382 
    383    private:
    384     vector<StateId> height_;
    385     size_t max_height_;
    386     size_t num_states_;
    387   };
    388 
    389   // helper methods
    390  private:
    391   // cluster states according to height (distance to final state)
    392   void Initialize(const Fst<A>& fst) {
    393     // compute height (distance to final state)
    394     HeightVisitor hvisitor;
    395     DfsVisit(fst, &hvisitor);
    396 
    397     // create initial partition based on height
    398     partition_.Initialize(hvisitor.num_states());
    399     partition_.AllocateClasses(hvisitor.max_height() + 1);
    400     const vector<StateId>& hstates = hvisitor.height();
    401     for (size_t s = 0; s < hstates.size(); ++s)
    402       partition_.Add(s, hstates[s]);
    403   }
    404 
    405   // refine states based on arc sort (out degree, arc equivalence)
    406   void Refine(const Fst<A>& fst) {
    407     typedef map<StateId, StateId, StateComparator<A> > EquivalenceMap;
    408     StateComparator<A> comp(fst, partition_);
    409 
    410     // start with tail (height = 0)
    411     size_t height = partition_.num_classes();
    412     for (size_t h = 0; h < height; ++h) {
    413       EquivalenceMap equiv_classes(comp);
    414 
    415       // sort states within equivalence class
    416       PartitionIterator<StateId> siter(partition_, h);
    417       equiv_classes[siter.Value()] = h;
    418       for (siter.Next(); !siter.Done(); siter.Next()) {
    419         const StateId s = siter.Value();
    420         typename EquivalenceMap::const_iterator it = equiv_classes.find(s);
    421         if (it == equiv_classes.end())
    422           equiv_classes[s] = partition_.AddClass();
    423         else
    424           equiv_classes[s] = it->second;
    425       }
    426 
    427       // create refined partition
    428       for (siter.Reset(); !siter.Done();) {
    429         const StateId s = siter.Value();
    430         const StateId old_class = partition_.class_id(s);
    431         const StateId new_class = equiv_classes[s];
    432 
    433         // a move operation can invalidate the iterator, so
    434         // we first update the iterator to the next element
    435         // before we move the current element out of the list
    436         siter.Next();
    437         if (old_class != new_class)
    438           partition_.Move(s, new_class);
    439       }
    440     }
    441   }
    442 
    443  private:
    444   Partition<StateId> partition_;
    445 };
    446 
    447 
    448 // Given a partition and a mutable fst, merge states of Fst inplace
    449 // (i.e. destructively). Merging works by taking the first state in
    450 // a class of the partition to be the representative state for the class.
    451 // Each arc is then reconnected to this state. All states in the class
    452 // are merged by adding there arcs to the representative state.
    453 template <class A>
    454 void MergeStates(
    455     const Partition<typename A::StateId>& partition, MutableFst<A>* fst) {
    456   typedef typename A::StateId StateId;
    457 
    458   vector<StateId> state_map(partition.num_classes());
    459   for (size_t i = 0; i < partition.num_classes(); ++i) {
    460     PartitionIterator<StateId> siter(partition, i);
    461     state_map[i] = siter.Value();  // first state in partition;
    462   }
    463 
    464   // relabel destination states
    465   for (size_t c = 0; c < partition.num_classes(); ++c) {
    466     for (PartitionIterator<StateId> siter(partition, c);
    467          !siter.Done(); siter.Next()) {
    468       StateId s = siter.Value();
    469       for (MutableArcIterator<MutableFst<A> > aiter(fst, s);
    470            !aiter.Done(); aiter.Next()) {
    471         A arc = aiter.Value();
    472         arc.nextstate = state_map[partition.class_id(arc.nextstate)];
    473 
    474         if (s == state_map[c])  // first state just set destination
    475           aiter.SetValue(arc);
    476         else
    477           fst->AddArc(state_map[c], arc);
    478       }
    479     }
    480   }
    481   fst->SetStart(state_map[partition.class_id(fst->Start())]);
    482 
    483   Connect(fst);
    484 }
    485 
    486 template <class A>
    487 void AcceptorMinimize(MutableFst<A>* fst) {
    488   typedef typename A::StateId StateId;
    489   if (!(fst->Properties(kAcceptor | kUnweighted, true))) {
    490     FSTERROR() << "FST is not an unweighted acceptor";
    491     fst->SetProperties(kError, kError);
    492     return;
    493   }
    494 
    495   // connect fst before minimization, handles disconnected states
    496   Connect(fst);
    497   if (fst->NumStates() == 0) return;
    498 
    499   if (fst->Properties(kAcyclic, true)) {
    500     // Acyclic minimization (revuz)
    501     VLOG(2) << "Acyclic Minimization";
    502     ArcSort(fst, ILabelCompare<A>());
    503     AcyclicMinimizer<A> minimizer(*fst);
    504     MergeStates(minimizer.partition(), fst);
    505 
    506   } else {
    507     // Cyclic minimizaton (hopcroft)
    508     VLOG(2) << "Cyclic Minimization";
    509     CyclicMinimizer<A, LifoQueue<StateId> > minimizer(*fst);
    510     MergeStates(minimizer.partition(), fst);
    511   }
    512 
    513   // Merge in appropriate semiring
    514   ArcUniqueMapper<A> mapper(*fst);
    515   StateMap(fst, mapper);
    516 }
    517 
    518 
    519 // In place minimization of deterministic weighted automata and transducers.
    520 // For transducers, then the 'sfst' argument is not null, the algorithm
    521 // produces a compact factorization of the minimal transducer.
    522 //
    523 // In the acyclic case, we use an algorithm from Dominique Revuz that
    524 // is linear in the number of arcs (edges) in the machine.
    525 //  Complexity = O(E)
    526 //
    527 // In the cyclic case, we use the classical hopcroft minimization.
    528 //  Complexity = O(|E|log(|N|)
    529 //
    530 template <class A>
    531 void Minimize(MutableFst<A>* fst,
    532               MutableFst<A>* sfst = 0,
    533               float delta = kDelta) {
    534   uint64 props = fst->Properties(kAcceptor | kIDeterministic|
    535                                  kWeighted | kUnweighted, true);
    536   if (!(props & kIDeterministic)) {
    537     FSTERROR() << "FST is not deterministic";
    538     fst->SetProperties(kError, kError);
    539     return;
    540   }
    541 
    542   if (!(props & kAcceptor)) {  // weighted transducer
    543     VectorFst< GallicArc<A, STRING_LEFT> > gfst;
    544     ArcMap(*fst, &gfst, ToGallicMapper<A, STRING_LEFT>());
    545     fst->DeleteStates();
    546     gfst.SetProperties(kAcceptor, kAcceptor);
    547     Push(&gfst, REWEIGHT_TO_INITIAL, delta);
    548     ArcMap(&gfst, QuantizeMapper< GallicArc<A, STRING_LEFT> >(delta));
    549     EncodeMapper< GallicArc<A, STRING_LEFT> >
    550       encoder(kEncodeLabels | kEncodeWeights, ENCODE);
    551     Encode(&gfst, &encoder);
    552     AcceptorMinimize(&gfst);
    553     Decode(&gfst, encoder);
    554 
    555     if (sfst == 0) {
    556       FactorWeightFst< GallicArc<A, STRING_LEFT>,
    557         GallicFactor<typename A::Label,
    558         typename A::Weight, STRING_LEFT> > fwfst(gfst);
    559       SymbolTable *osyms = fst->OutputSymbols() ?
    560           fst->OutputSymbols()->Copy() : 0;
    561       ArcMap(fwfst, fst, FromGallicMapper<A, STRING_LEFT>());
    562       fst->SetOutputSymbols(osyms);
    563       delete osyms;
    564     } else {
    565       sfst->SetOutputSymbols(fst->OutputSymbols());
    566       GallicToNewSymbolsMapper<A, STRING_LEFT> mapper(sfst);
    567       ArcMap(gfst, fst, &mapper);
    568       fst->SetOutputSymbols(sfst->InputSymbols());
    569     }
    570   } else if (props & kWeighted) {  // weighted acceptor
    571     Push(fst, REWEIGHT_TO_INITIAL, delta);
    572     ArcMap(fst, QuantizeMapper<A>(delta));
    573     EncodeMapper<A> encoder(kEncodeLabels | kEncodeWeights, ENCODE);
    574     Encode(fst, &encoder);
    575     AcceptorMinimize(fst);
    576     Decode(fst, encoder);
    577   } else {  // unweighted acceptor
    578     AcceptorMinimize(fst);
    579   }
    580 }
    581 
    582 }  // namespace fst
    583 
    584 #endif  // FST_LIB_MINIMIZE_H__
    585