// connect.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: riley (at) google.com (Michael Riley)
//
// \file
// Classes and functions to remove unsuccessful paths from an Fst.
     20 
     21 #ifndef FST_LIB_CONNECT_H__
     22 #define FST_LIB_CONNECT_H__
     23 
     24 #include <vector>
     25 using std::vector;
     26 
     27 #include <fst/dfs-visit.h>
     28 #include <fst/union-find.h>
     29 #include <fst/mutable-fst.h>
     30 
     31 
     32 namespace fst {
     33 
     34 // Finds and returns connected components. Use with Visit().
     35 template <class A>
     36 class CcVisitor {
     37  public:
     38   typedef A Arc;
     39   typedef typename Arc::Weight Weight;
     40   typedef typename A::StateId StateId;
     41 
     42   // cc[i]: connected component number for state i.
     43   CcVisitor(vector<StateId> *cc)
     44       : comps_(new UnionFind<StateId>(0, kNoStateId)),
     45         cc_(cc),
     46         nstates_(0) { }
     47 
     48   // comps: connected components equiv classes.
     49   CcVisitor(UnionFind<StateId> *comps)
     50       : comps_(comps),
     51         cc_(0),
     52         nstates_(0) { }
     53 
     54   ~CcVisitor() {
     55     if (cc_)  // own comps_?
     56       delete comps_;
     57   }
     58 
     59   void InitVisit(const Fst<A> &fst) { }
     60 
     61   bool InitState(StateId s, StateId root) {
     62     ++nstates_;
     63     if (comps_->FindSet(s) == kNoStateId)
     64       comps_->MakeSet(s);
     65     return true;
     66   }
     67 
     68   bool WhiteArc(StateId s, const A &arc) {
     69     comps_->MakeSet(arc.nextstate);
     70     comps_->Union(s, arc.nextstate);
     71     return true;
     72   }
     73 
     74   bool GreyArc(StateId s, const A &arc) {
     75     comps_->Union(s, arc.nextstate);
     76     return true;
     77   }
     78 
     79   bool BlackArc(StateId s, const A &arc) {
     80     comps_->Union(s, arc.nextstate);
     81     return true;
     82   }
     83 
     84   void FinishState(StateId s) { }
     85 
     86   void FinishVisit() {
     87     if (cc_)
     88       GetCcVector(cc_);
     89   }
     90 
     91   // cc[i]: connected component number for state i.
     92   // Returns number of components.
     93   int GetCcVector(vector<StateId> *cc) {
     94     cc->clear();
     95     cc->resize(nstates_, kNoStateId);
     96     StateId ncomp = 0;
     97     for (StateId i = 0; i < nstates_; ++i) {
     98       StateId rep = comps_->FindSet(i);
     99       StateId &comp = (*cc)[rep];
    100       if (comp == kNoStateId) {
    101         comp = ncomp;
    102         ++ncomp;
    103       }
    104       (*cc)[i] = comp;
    105     }
    106     return ncomp;
    107   }
    108 
    109  private:
    110   UnionFind<StateId> *comps_;   // Components
    111   vector<StateId> *cc_;         // State's cc number
    112   StateId nstates_;             // State count
    113 };
    114 
    115 
    116 // Finds and returns strongly-connected components, accessible and
    117 // coaccessible states and related properties. Uses Tarjan's single
    118 // DFS SCC algorithm (see Aho, et al, "Design and Analysis of Computer
    119 // Algorithms", 189pp). Use with DfsVisit();
    120 template <class A>
    121 class SccVisitor {
    122  public:
    123   typedef A Arc;
    124   typedef typename A::Weight Weight;
    125   typedef typename A::StateId StateId;
    126 
    127   // scc[i]: strongly-connected component number for state i.
    128   //   SCC numbers will be in topological order for acyclic input.
    129   // access[i]: accessibility of state i.
    130   // coaccess[i]: coaccessibility of state i.
    131   // Any of above can be NULL.
    132   // props: related property bits (cyclicity, initial cyclicity,
    133   //   accessibility, coaccessibility) set/cleared (o.w. unchanged).
    134   SccVisitor(vector<StateId> *scc, vector<bool> *access,
    135              vector<bool> *coaccess, uint64 *props)
    136       : scc_(scc), access_(access), coaccess_(coaccess), props_(props) {}
    137   SccVisitor(uint64 *props)
    138       : scc_(0), access_(0), coaccess_(0), props_(props) {}
    139 
    140   void InitVisit(const Fst<A> &fst);
    141 
    142   bool InitState(StateId s, StateId root);
    143 
    144   bool TreeArc(StateId s, const A &arc) { return true; }
    145 
    146   bool BackArc(StateId s, const A &arc) {
    147     StateId t = arc.nextstate;
    148     if ((*dfnumber_)[t] < (*lowlink_)[s])
    149       (*lowlink_)[s] = (*dfnumber_)[t];
    150     if ((*coaccess_)[t])
    151       (*coaccess_)[s] = true;
    152     *props_ |= kCyclic;
    153     *props_ &= ~kAcyclic;
    154     if (arc.nextstate == start_) {
    155       *props_ |= kInitialCyclic;
    156       *props_ &= ~kInitialAcyclic;
    157     }
    158     return true;
    159   }
    160 
    161   bool ForwardOrCrossArc(StateId s, const A &arc) {
    162     StateId t = arc.nextstate;
    163     if ((*dfnumber_)[t] < (*dfnumber_)[s] /* cross edge */ &&
    164         (*onstack_)[t] && (*dfnumber_)[t] < (*lowlink_)[s])
    165       (*lowlink_)[s] = (*dfnumber_)[t];
    166     if ((*coaccess_)[t])
    167       (*coaccess_)[s] = true;
    168     return true;
    169   }
    170 
    171   void FinishState(StateId s, StateId p, const A *);
    172 
    173   void FinishVisit() {
    174     // Numbers SCC's in topological order when acyclic.
    175     if (scc_)
    176       for (StateId i = 0; i < scc_->size(); ++i)
    177         (*scc_)[i] = nscc_ - 1 - (*scc_)[i];
    178     if (coaccess_internal_)
    179       delete coaccess_;
    180     delete dfnumber_;
    181     delete lowlink_;
    182     delete onstack_;
    183     delete scc_stack_;
    184   }
    185 
    186  private:
    187   vector<StateId> *scc_;        // State's scc number
    188   vector<bool> *access_;        // State's accessibility
    189   vector<bool> *coaccess_;      // State's coaccessibility
    190   uint64 *props_;
    191   const Fst<A> *fst_;
    192   StateId start_;
    193   StateId nstates_;             // State count
    194   StateId nscc_;                // SCC count
    195   bool coaccess_internal_;
    196   vector<StateId> *dfnumber_;   // state discovery times
    197   vector<StateId> *lowlink_;    // lowlink[s] == dfnumber[s] => SCC root
    198   vector<bool> *onstack_;       // is a state on the SCC stack
    199   vector<StateId> *scc_stack_;  // SCC stack (w/ random access)
    200 };
    201 
    202 template <class A> inline
    203 void SccVisitor<A>::InitVisit(const Fst<A> &fst) {
    204   if (scc_)
    205     scc_->clear();
    206   if (access_)
    207     access_->clear();
    208   if (coaccess_) {
    209     coaccess_->clear();
    210     coaccess_internal_ = false;
    211   } else {
    212     coaccess_ = new vector<bool>;
    213     coaccess_internal_ = true;
    214   }
    215   *props_ |= kAcyclic | kInitialAcyclic | kAccessible | kCoAccessible;
    216   *props_ &= ~(kCyclic | kInitialCyclic | kNotAccessible | kNotCoAccessible);
    217   fst_ = &fst;
    218   start_ = fst.Start();
    219   nstates_ = 0;
    220   nscc_ = 0;
    221   dfnumber_ = new vector<StateId>;
    222   lowlink_ = new vector<StateId>;
    223   onstack_ = new vector<bool>;
    224   scc_stack_ = new vector<StateId>;
    225 }
    226 
    227 template <class A> inline
    228 bool SccVisitor<A>::InitState(StateId s, StateId root) {
    229   scc_stack_->push_back(s);
    230   while (dfnumber_->size() <= s) {
    231     if (scc_)
    232       scc_->push_back(-1);
    233     if (access_)
    234       access_->push_back(false);
    235     coaccess_->push_back(false);
    236     dfnumber_->push_back(-1);
    237     lowlink_->push_back(-1);
    238     onstack_->push_back(false);
    239   }
    240   (*dfnumber_)[s] = nstates_;
    241   (*lowlink_)[s] = nstates_;
    242   (*onstack_)[s] = true;
    243   if (root == start_) {
    244     if (access_)
    245       (*access_)[s] = true;
    246   } else {
    247     if (access_)
    248       (*access_)[s] = false;
    249     *props_ |= kNotAccessible;
    250     *props_ &= ~kAccessible;
    251   }
    252   ++nstates_;
    253   return true;
    254 }
    255 
    256 template <class A> inline
    257 void SccVisitor<A>::FinishState(StateId s, StateId p, const A *) {
    258   if (fst_->Final(s) != Weight::Zero())
    259     (*coaccess_)[s] = true;
    260   if ((*dfnumber_)[s] == (*lowlink_)[s]) {  // root of new SCC
    261     bool scc_coaccess = false;
    262     size_t i = scc_stack_->size();
    263     StateId t;
    264     do {
    265       t = (*scc_stack_)[--i];
    266       if ((*coaccess_)[t])
    267         scc_coaccess = true;
    268     } while (s != t);
    269     do {
    270       t = scc_stack_->back();
    271       if (scc_)
    272         (*scc_)[t] = nscc_;
    273       if (scc_coaccess)
    274         (*coaccess_)[t] = true;
    275       (*onstack_)[t] = false;
    276       scc_stack_->pop_back();
    277     } while (s != t);
    278     if (!scc_coaccess) {
    279       *props_ |= kNotCoAccessible;
    280       *props_ &= ~kCoAccessible;
    281     }
    282     ++nscc_;
    283   }
    284   if (p != kNoStateId) {
    285     if ((*coaccess_)[s])
    286       (*coaccess_)[p] = true;
    287     if ((*lowlink_)[s] < (*lowlink_)[p])
    288       (*lowlink_)[p] = (*lowlink_)[s];
    289   }
    290 }
    291 
    292 
    293 // Trims an FST, removing states and arcs that are not on successful
    294 // paths. This version modifies its input.
    295 //
    296 // Complexity:
    297 // - Time:  O(V + E)
    298 // - Space: O(V + E)
    299 // where V = # of states and E = # of arcs.
    300 template<class Arc>
    301 void Connect(MutableFst<Arc> *fst) {
    302   typedef typename Arc::StateId StateId;
    303 
    304   vector<bool> access;
    305   vector<bool> coaccess;
    306   uint64 props = 0;
    307   SccVisitor<Arc> scc_visitor(0, &access, &coaccess, &props);
    308   DfsVisit(*fst, &scc_visitor);
    309   vector<StateId> dstates;
    310   for (StateId s = 0; s < access.size(); ++s)
    311     if (!access[s] || !coaccess[s])
    312       dstates.push_back(s);
    313   fst->DeleteStates(dstates);
    314   fst->SetProperties(kAccessible | kCoAccessible, kAccessible | kCoAccessible);
    315 }
    316 
    317 }  // namespace fst
    318 
    319 #endif  // FST_LIB_CONNECT_H__
    320