Home | History | Annotate | Download | only in lib
      1 // minimize.h
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //      http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 //
     16 // \file Functions and classes to minimize a finite state acceptor
     17 
     18 #ifndef FST_LIB_MINIMIZE_H__
     19 #define FST_LIB_MINIMIZE_H__
     20 
     21 #include <algorithm>
     22 #include <map>
     23 #include <queue>
     24 
     25 #include "fst/lib/arcsort.h"
     26 #include "fst/lib/arcsum.h"
     27 #include "fst/lib/connect.h"
     28 #include "fst/lib/dfs-visit.h"
     29 #include "fst/lib/encode.h"
     30 #include "fst/lib/factor-weight.h"
     31 #include "fst/lib/fst.h"
     32 #include "fst/lib/mutable-fst.h"
     33 #include "fst/lib/partition.h"
     34 #include "fst/lib/push.h"
     35 #include "fst/lib/queue.h"
     36 #include "fst/lib/reverse.h"
     37 
     38 namespace fst {
     39 
     40 // comparator for creating partition based on sorting on
     41 // - states
     42 // - final weight
     43 // - out degree,
     44 // -  (input label, output label, weight, destination_block)
     45 template <class A>
     46 class StateComparator {
     47  public:
     48   typedef typename A::StateId StateId;
     49   typedef typename A::Weight Weight;
     50 
     51   static const int32 kCompareFinal     = 0x0000001;
     52   static const int32 kCompareOutDegree = 0x0000002;
     53   static const int32 kCompareArcs      = 0x0000004;
     54   static const int32 kCompareAll       = (kCompareFinal |
     55                                           kCompareOutDegree |
     56                                           kCompareArcs);
     57 
     58   StateComparator(const Fst<A>& fst,
     59                   const Partition<typename A::StateId>& partition,
     60                   int32 flags = kCompareAll)
     61       : fst_(fst), partition_(partition), flags_(flags) {}
     62 
     63   // compare state x with state y based on sort criteria
     64   bool operator()(const StateId x, const StateId y) const {
     65     // check for final state equivalence
     66     if (flags_ & kCompareFinal) {
     67       const ssize_t xfinal = fst_.Final(x).Hash();
     68       const ssize_t yfinal = fst_.Final(y).Hash();
     69       if      (xfinal < yfinal) return true;
     70       else if (xfinal > yfinal) return false;
     71     }
     72 
     73     if (flags_ & kCompareOutDegree) {
     74       // check for # arcs
     75       if (fst_.NumArcs(x) < fst_.NumArcs(y)) return true;
     76       if (fst_.NumArcs(x) > fst_.NumArcs(y)) return false;
     77 
     78       if (flags_ & kCompareArcs) {
     79         // # arcs are equal, check for arc match
     80         for (ArcIterator<Fst<A> > aiter1(fst_, x), aiter2(fst_, y);
     81              !aiter1.Done() && !aiter2.Done(); aiter1.Next(), aiter2.Next()) {
     82           const A& arc1 = aiter1.Value();
     83           const A& arc2 = aiter2.Value();
     84           if (arc1.ilabel < arc2.ilabel) return true;
     85           if (arc1.ilabel > arc2.ilabel) return false;
     86 
     87           if (partition_.class_id(arc1.nextstate) <
     88               partition_.class_id(arc2.nextstate)) return true;
     89           if (partition_.class_id(arc1.nextstate) >
     90               partition_.class_id(arc2.nextstate)) return false;
     91         }
     92       }
     93     }
     94 
     95     return false;
     96   }
     97 
     98  private:
     99   const Fst<A>& fst_;
    100   const Partition<typename A::StateId>& partition_;
    101   const int32 flags_;
    102 };
    103 
    104 // Computes equivalence classes for cyclic Fsts. For cyclic minimization
    105 // we use the classic HopCroft minimization algorithm, which is of
    106 //
    107 //   O(E)log(N),
    108 //
    109 // where E is the number of edges in the machine and N is number of states.
    110 //
    111 // The following paper describes the original algorithm
    112 //  An N Log N algorithm for minimizing states in a finite automaton
    113 //  by John HopCroft, January 1971
    114 //
    115 template <class A, class Queue>
    116 class CyclicMinimizer {
    117  public:
    118   typedef typename A::Label Label;
    119   typedef typename A::StateId StateId;
    120   typedef typename A::StateId ClassId;
    121   typedef typename A::Weight Weight;
    122   typedef ReverseArc<A> RevA;
    123 
    124   CyclicMinimizer(const ExpandedFst<A>& fst) {
    125     Initialize(fst);
    126     Compute(fst);
    127   }
    128 
    129   ~CyclicMinimizer() {
    130     delete aiter_queue_;
    131   }
    132 
    133   const Partition<StateId>& partition() const {
    134     return P_;
    135   }
    136 
    137   // helper classes
    138  private:
    139   typedef ArcIterator<Fst<RevA> > ArcIter;
    140   class ArcIterCompare {
    141    public:
    142     ArcIterCompare(const Partition<StateId>& partition)
    143         : partition_(partition) {}
    144 
    145     ArcIterCompare(const ArcIterCompare& comp)
    146         : partition_(comp.partition_) {}
    147 
    148     // compare two iterators based on there input labels, and proto state
    149     // (partition class Ids)
    150     bool operator()(const ArcIter* x, const ArcIter* y) const {
    151       const RevA& xarc = x->Value();
    152       const RevA& yarc = y->Value();
    153       return (xarc.ilabel > yarc.ilabel);
    154     }
    155 
    156    private:
    157     const Partition<StateId>& partition_;
    158   };
    159 
    160   typedef priority_queue<ArcIter*, vector<ArcIter*>, ArcIterCompare>
    161   ArcIterQueue;
    162 
    163   // helper methods
    164  private:
    165   // prepartitions the space into equivalence classes with
    166   //   same final weight
    167   //   same # arcs per state
    168   //   same outgoing arcs
    169   void PrePartition(const Fst<A>& fst) {
    170     VLOG(5) << "PrePartition";
    171 
    172     typedef map<StateId, StateId, StateComparator<A> > EquivalenceMap;
    173     StateComparator<A> comp(fst, P_, StateComparator<A>::kCompareFinal);
    174     EquivalenceMap equiv_map(comp);
    175 
    176     StateIterator<Fst<A> > siter(fst);
    177     StateId class_id = P_.AddClass();
    178     P_.Add(siter.Value(), class_id);
    179     equiv_map[siter.Value()] = class_id;
    180     L_.Enqueue(class_id);
    181     for (siter.Next(); !siter.Done(); siter.Next()) {
    182       StateId  s = siter.Value();
    183       typename EquivalenceMap::const_iterator it = equiv_map.find(s);
    184       if (it == equiv_map.end()) {
    185         class_id = P_.AddClass();
    186         P_.Add(s, class_id);
    187         equiv_map[s] = class_id;
    188         L_.Enqueue(class_id);
    189       } else {
    190         P_.Add(s, it->second);
    191         equiv_map[s] = it->second;
    192       }
    193     }
    194 
    195     VLOG(5) << "Initial Partition: " << P_.num_classes();
    196   }
    197 
    198   // - Create inverse transition Tr_ = rev(fst)
    199   // - loop over states in fst and split on final, creating two blocks
    200   //   in the partition corresponding to final, non-final
    201   void Initialize(const Fst<A>& fst) {
    202     // construct Tr
    203     Reverse(fst, &Tr_);
    204     ILabelCompare<RevA> ilabel_comp;
    205     ArcSort(&Tr_, ilabel_comp);
    206 
    207     // initial split (F, S - F)
    208     P_.Initialize(Tr_.NumStates() - 1);
    209 
    210     // prep partition
    211     PrePartition(fst);
    212 
    213     // allocate arc iterator queue
    214     ArcIterCompare comp(P_);
    215     aiter_queue_ = new ArcIterQueue(comp);
    216   }
    217 
    218   // partition all classes with destination C
    219   void Split(ClassId C) {
    220     // Prep priority queue. Open arc iterator for each state in C, and
    221     // insert into priority queue.
    222     for (PartitionIterator<StateId> siter(P_, C);
    223          !siter.Done(); siter.Next()) {
    224       StateId s = siter.Value();
    225       if (Tr_.NumArcs(s + 1))
    226         aiter_queue_->push(new ArcIterator<Fst<RevA> >(Tr_, s + 1));
    227     }
    228 
    229     // Now pop arc iterator from queue, split entering equivalence class
    230     // re-insert updated iterator into queue.
    231     Label prev_label = -1;
    232     while (!aiter_queue_->empty()) {
    233       ArcIterator<Fst<RevA> >* aiter = aiter_queue_->top();
    234       aiter_queue_->pop();
    235       if (aiter->Done()) {
    236         delete aiter;
    237         continue;
    238      }
    239 
    240       const RevA& arc = aiter->Value();
    241       StateId from_state = aiter->Value().nextstate - 1;
    242       Label   from_label = arc.ilabel;
    243       if (prev_label != from_label)
    244         P_.FinalizeSplit(&L_);
    245 
    246       StateId from_class = P_.class_id(from_state);
    247       if (P_.class_size(from_class) > 1)
    248         P_.SplitOn(from_state);
    249 
    250       prev_label = from_label;
    251       aiter->Next();
    252       if (aiter->Done())
    253         delete aiter;
    254       else
    255         aiter_queue_->push(aiter);
    256     }
    257     P_.FinalizeSplit(&L_);
    258   }
    259 
    260   // Main loop for hopcroft minimization.
    261   void Compute(const Fst<A>& fst) {
    262     // process active classes (FIFO, or FILO)
    263     while (!L_.Empty()) {
    264       ClassId C = L_.Head();
    265       L_.Dequeue();
    266 
    267       // split on C, all labels in C
    268       Split(C);
    269     }
    270   }
    271 
    272   // helper data
    273  private:
    274   // Partioning of states into equivalence classes
    275   Partition<StateId> P_;
    276 
    277   // L = set of active classes to be processed in partition P
    278   Queue L_;
    279 
    280   // reverse transition function
    281   VectorFst<RevA> Tr_;
    282 
    283   // Priority queue of open arc iterators for all states in the 'splitter'
    284   // equivalence class
    285   ArcIterQueue* aiter_queue_;
    286 };
    287 
    288 
    289 // Computes equivalence classes for acyclic Fsts. The implementation details
    290 // for this algorithms is documented by the following paper.
    291 //
    292 // Minimization of acyclic deterministic automata in linear time
    293 //  Dominque Revuz
    294 //
    295 // Complexity O(|E|)
    296 //
    297 template <class A>
    298 class AcyclicMinimizer {
    299  public:
    300   typedef typename A::Label Label;
    301   typedef typename A::StateId StateId;
    302   typedef typename A::StateId ClassId;
    303   typedef typename A::Weight Weight;
    304 
    305   AcyclicMinimizer(const ExpandedFst<A>& fst) {
    306     Initialize(fst);
    307     Refine(fst);
    308   }
    309 
    310   const Partition<StateId>& partition() {
    311     return partition_;
    312   }
    313 
    314   // helper classes
    315  private:
    316   // DFS visitor to compute the height (distance) to final state.
    317   class HeightVisitor {
    318    public:
    319     HeightVisitor() : max_height_(0), num_states_(0) { }
    320 
    321     // invoked before dfs visit
    322     void InitVisit(const Fst<A>& fst) {}
    323 
    324     // invoked when state is discovered (2nd arg is DFS tree root)
    325     bool InitState(StateId s, StateId root) {
    326       // extend height array and initialize height (distance) to 0
    327       for (size_t i = height_.size(); i <= (size_t)s; ++i)
    328         height_.push_back(-1);
    329 
    330       if (s >= (StateId)num_states_) num_states_ = s + 1;
    331       return true;
    332     }
    333 
    334     // invoked when tree arc examined (to undiscoverted state)
    335     bool TreeArc(StateId s, const A& arc) {
    336       return true;
    337     }
    338 
    339     // invoked when back arc examined (to unfinished state)
    340     bool BackArc(StateId s, const A& arc) {
    341       return true;
    342     }
    343 
    344     // invoked when forward or cross arc examined (to finished state)
    345     bool ForwardOrCrossArc(StateId s, const A& arc) {
    346       if (height_[arc.nextstate] + 1 > height_[s])
    347         height_[s] = height_[arc.nextstate] + 1;
    348       return true;
    349     }
    350 
    351     // invoked when state finished (parent is kNoStateId for tree root)
    352     void FinishState(StateId s, StateId parent, const A* parent_arc) {
    353       if (height_[s] == -1) height_[s] = 0;
    354       StateId h = height_[s] +  1;
    355       if (parent >= 0) {
    356         if (h > height_[parent]) height_[parent] = h;
    357         if (h > (StateId)max_height_)     max_height_ = h;
    358       }
    359     }
    360 
    361     // invoked after DFS visit
    362     void FinishVisit() {}
    363 
    364     size_t max_height() const { return max_height_; }
    365 
    366     const vector<StateId>& height() const { return height_; }
    367 
    368     const size_t num_states() const { return num_states_; }
    369 
    370    private:
    371     vector<StateId> height_;
    372     size_t max_height_;
    373     size_t num_states_;
    374   };
    375 
    376   // helper methods
    377  private:
    378   // cluster states according to height (distance to final state)
    379   void Initialize(const Fst<A>& fst) {
    380     // compute height (distance to final state)
    381     HeightVisitor hvisitor;
    382     DfsVisit(fst, &hvisitor);
    383 
    384     // create initial partition based on height
    385     partition_.Initialize(hvisitor.num_states());
    386     partition_.AllocateClasses(hvisitor.max_height() + 1);
    387     const vector<StateId>& hstates = hvisitor.height();
    388     for (size_t s = 0; s < hstates.size(); ++s)
    389       partition_.Add(s, hstates[s]);
    390   }
    391 
    392   // refine states based on arc sort (out degree, arc equivalence)
    393   void Refine(const Fst<A>& fst) {
    394     typedef map<StateId, StateId, StateComparator<A> > EquivalenceMap;
    395     StateComparator<A> comp(fst, partition_);
    396 
    397     // start with tail (height = 0)
    398     size_t height = partition_.num_classes();
    399     for (size_t h = 0; h < height; ++h) {
    400       EquivalenceMap equiv_classes(comp);
    401 
    402       // sort states within equivalence class
    403       PartitionIterator<StateId> siter(partition_, h);
    404       equiv_classes[siter.Value()] = h;
    405       for (siter.Next(); !siter.Done(); siter.Next()) {
    406         const StateId s = siter.Value();
    407         typename EquivalenceMap::const_iterator it = equiv_classes.find(s);
    408         if (it == equiv_classes.end())
    409           equiv_classes[s] = partition_.AddClass();
    410         else
    411           equiv_classes[s] = it->second;
    412       }
    413 
    414       // create refined partition
    415       for (siter.Reset(); !siter.Done();) {
    416         const StateId s = siter.Value();
    417         const StateId old_class = partition_.class_id(s);
    418         const StateId new_class = equiv_classes[s];
    419 
    420         // a move operation can invalidate the iterator, so
    421         // we first update the iterator to the next element
    422         // before we move the current element out of the list
    423         siter.Next();
    424         if (old_class != new_class)
    425           partition_.Move(s, new_class);
    426       }
    427     }
    428   }
    429 
    430  private:
    431   Partition<StateId> partition_;
    432 };
    433 
    434 
    435 // Given a partition and a mutable fst, merge states of Fst inplace
    436 // (i.e. destructively). Merging works by taking the first state in
    437 // a class of the partition to be the representative state for the class.
    438 // Each arc is then reconnected to this state. All states in the class
    439 // are merged by adding there arcs to the representative state.
    440 template <class A>
    441 void MergeStates(
    442     const Partition<typename A::StateId>& partition, MutableFst<A>* fst) {
    443   typedef typename A::StateId StateId;
    444 
    445   vector<StateId> state_map(partition.num_classes());
    446   for (size_t i = 0; i < (size_t)partition.num_classes(); ++i) {
    447     PartitionIterator<StateId> siter(partition, i);
    448     state_map[i] = siter.Value();  // first state in partition;
    449   }
    450 
    451   // relabel destination states
    452   for (size_t c = 0; c < (size_t)partition.num_classes(); ++c) {
    453     for (PartitionIterator<StateId> siter(partition, c);
    454          !siter.Done(); siter.Next()) {
    455       StateId s = siter.Value();
    456       for (MutableArcIterator<MutableFst<A> > aiter(fst, s);
    457            !aiter.Done(); aiter.Next()) {
    458         A arc = aiter.Value();
    459         arc.nextstate = state_map[partition.class_id(arc.nextstate)];
    460 
    461         if (s == state_map[c])  // first state just set destination
    462           aiter.SetValue(arc);
    463         else
    464           fst->AddArc(state_map[c], arc);
    465       }
    466     }
    467   }
    468   fst->SetStart(state_map[partition.class_id(fst->Start())]);
    469 
    470   Connect(fst);
    471 }
    472 
    473 template <class A>
    474 void AcceptorMinimize(MutableFst<A>* fst) {
    475   typedef typename A::StateId StateId;
    476   if (!(fst->Properties(kAcceptor | kUnweighted, true)))
    477     LOG(FATAL) << "Input Fst is not an unweighted acceptor";
    478 
    479   // connect fst before minimization, handles disconnected states
    480   Connect(fst);
    481   if (fst->NumStates() == 0) return;
    482 
    483   if (fst->Properties(kAcyclic, true)) {
    484     // Acyclic minimization (revuz)
    485     VLOG(2) << "Acyclic Minimization";
    486     AcyclicMinimizer<A> minimizer(*fst);
    487     MergeStates(minimizer.partition(), fst);
    488 
    489   } else {
    490     // Cyclic minimizaton (hopcroft)
    491     VLOG(2) << "Cyclic Minimization";
    492     CyclicMinimizer<A, LifoQueue<StateId> > minimizer(*fst);
    493     MergeStates(minimizer.partition(), fst);
    494   }
    495 
    496   // sort arcs before summing
    497   ArcSort(fst, ILabelCompare<A>());
    498 
    499   // sum in appropriate semiring
    500   ArcSum(fst);
    501 }
    502 
    503 
    504 // In place minimization of unweighted, deterministic acceptors
    505 //
    506 // For acyclic automata we use an algorithm from Dominique Revuz that is
    507 // linear in the number of arcs (edges) in the machine.
    508 //  Complexity = O(E)
    509 //
    510 // For cyclic automata we use the classical hopcroft minimization.
    511 //  Complexity = O(|E|log(|N|)
    512 //
    513 template <class A>
    514 void Minimize(MutableFst<A>* fst, MutableFst<A>* sfst = 0) {
    515   uint64 props = fst->Properties(kAcceptor | kIDeterministic|
    516                                  kWeighted | kUnweighted, true);
    517   if (!(props & kIDeterministic))
    518     LOG(FATAL) << "Input Fst is not deterministic";
    519 
    520   if (!(props & kAcceptor)) {  // weighted transducer
    521     VectorFst< GallicArc<A, STRING_LEFT> > gfst;
    522     Map(*fst, &gfst, ToGallicMapper<A, STRING_LEFT>());
    523     fst->DeleteStates();
    524     gfst.SetProperties(kAcceptor, kAcceptor);
    525     Push(&gfst, REWEIGHT_TO_INITIAL);
    526     Map(&gfst, QuantizeMapper< GallicArc<A, STRING_LEFT> >());
    527     EncodeMapper< GallicArc<A, STRING_LEFT> >
    528       encoder(kEncodeLabels | kEncodeWeights, ENCODE);
    529     Encode(&gfst, &encoder);
    530     AcceptorMinimize(&gfst);
    531     Decode(&gfst, encoder);
    532 
    533     if (sfst == 0) {
    534       FactorWeightFst< GallicArc<A, STRING_LEFT>,
    535         GallicFactor<typename A::Label,
    536         typename A::Weight, STRING_LEFT> > fwfst(gfst);
    537       Map(fwfst, fst, FromGallicMapper<A, STRING_LEFT>());
    538     } else {
    539       sfst->SetOutputSymbols(fst->OutputSymbols());
    540       GallicToNewSymbolsMapper<A, STRING_LEFT> mapper(sfst);
    541       Map(gfst, fst, mapper);
    542       fst->SetOutputSymbols(sfst->InputSymbols());
    543     }
    544   } else if (props & kWeighted) {  // weighted acceptor
    545     Push(fst, REWEIGHT_TO_INITIAL);
    546     Map(fst, QuantizeMapper<A>());
    547     EncodeMapper<A> encoder(kEncodeLabels | kEncodeWeights, ENCODE);
    548     Encode(fst, &encoder);
    549     AcceptorMinimize(fst);
    550     Decode(fst, encoder);
    551   } else {  // unweighted acceptor
    552     AcceptorMinimize(fst);
    553   }
    554 }
    555 
    556 }  // namespace fst
    557 
    558 #endif  // FST_LIB_MINIMIZE_H__
    559