1 // connect.h 2 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Copyright 2005-2010 Google, Inc. 16 // Author: riley (at) google.com (Michael Riley) 17 // 18 // \file 19 // Classes and functions to remove unsuccessful paths from an Fst. 20 21 #ifndef FST_LIB_CONNECT_H__ 22 #define FST_LIB_CONNECT_H__ 23 24 #include <vector> 25 using std::vector; 26 27 #include <fst/dfs-visit.h> 28 #include <fst/union-find.h> 29 #include <fst/mutable-fst.h> 30 31 32 namespace fst { 33 34 // Finds and returns connected components. Use with Visit(). 35 template <class A> 36 class CcVisitor { 37 public: 38 typedef A Arc; 39 typedef typename Arc::Weight Weight; 40 typedef typename A::StateId StateId; 41 42 // cc[i]: connected component number for state i. 43 CcVisitor(vector<StateId> *cc) 44 : comps_(new UnionFind<StateId>(0, kNoStateId)), 45 cc_(cc), 46 nstates_(0) { } 47 48 // comps: connected components equiv classes. 49 CcVisitor(UnionFind<StateId> *comps) 50 : comps_(comps), 51 cc_(0), 52 nstates_(0) { } 53 54 ~CcVisitor() { 55 if (cc_) // own comps_? 56 delete comps_; 57 } 58 59 void InitVisit(const Fst<A> &fst) { } 60 61 bool InitState(StateId s, StateId root) { 62 ++nstates_; 63 if (comps_->FindSet(s) == kNoStateId) 64 comps_->MakeSet(s); 65 return true; 66 } 67 68 bool WhiteArc(StateId s, const A &arc) { 69 comps_->MakeSet(arc.nextstate); 70 comps_->Union(s, arc.nextstate); 71 return true; 72 } 73 74 bool GreyArc(StateId s, const A &arc) { 75 comps_->Union(s, arc.nextstate); 76 return true; 77 } 78 79 bool BlackArc(StateId s, const A &arc) { 80 comps_->Union(s, arc.nextstate); 81 return true; 82 } 83 84 void FinishState(StateId s) { } 85 86 void FinishVisit() { 87 if (cc_) 88 GetCcVector(cc_); 89 } 90 91 // cc[i]: connected component number for state i. 92 // Returns number of components. 93 int GetCcVector(vector<StateId> *cc) { 94 cc->clear(); 95 cc->resize(nstates_, kNoStateId); 96 StateId ncomp = 0; 97 for (StateId i = 0; i < nstates_; ++i) { 98 StateId rep = comps_->FindSet(i); 99 StateId &comp = (*cc)[rep]; 100 if (comp == kNoStateId) { 101 comp = ncomp; 102 ++ncomp; 103 } 104 (*cc)[i] = comp; 105 } 106 return ncomp; 107 } 108 109 private: 110 UnionFind<StateId> *comps_; // Components 111 vector<StateId> *cc_; // State's cc number 112 StateId nstates_; // State count 113 }; 114 115 116 // Finds and returns strongly-connected components, accessible and 117 // coaccessible states and related properties. Uses Tarjan's single 118 // DFS SCC algorithm (see Aho, et al, "Design and Analysis of Computer 119 // Algorithms", 189pp). Use with DfsVisit(); 120 template <class A> 121 class SccVisitor { 122 public: 123 typedef A Arc; 124 typedef typename A::Weight Weight; 125 typedef typename A::StateId StateId; 126 127 // scc[i]: strongly-connected component number for state i. 128 // SCC numbers will be in topological order for acyclic input. 129 // access[i]: accessibility of state i. 130 // coaccess[i]: coaccessibility of state i. 131 // Any of above can be NULL. 132 // props: related property bits (cyclicity, initial cyclicity, 133 // accessibility, coaccessibility) set/cleared (o.w. unchanged). 134 SccVisitor(vector<StateId> *scc, vector<bool> *access, 135 vector<bool> *coaccess, uint64 *props) 136 : scc_(scc), access_(access), coaccess_(coaccess), props_(props) {} 137 SccVisitor(uint64 *props) 138 : scc_(0), access_(0), coaccess_(0), props_(props) {} 139 140 void InitVisit(const Fst<A> &fst); 141 142 bool InitState(StateId s, StateId root); 143 144 bool TreeArc(StateId s, const A &arc) { return true; } 145 146 bool BackArc(StateId s, const A &arc) { 147 StateId t = arc.nextstate; 148 if ((*dfnumber_)[t] < (*lowlink_)[s]) 149 (*lowlink_)[s] = (*dfnumber_)[t]; 150 if ((*coaccess_)[t]) 151 (*coaccess_)[s] = true; 152 *props_ |= kCyclic; 153 *props_ &= ~kAcyclic; 154 if (arc.nextstate == start_) { 155 *props_ |= kInitialCyclic; 156 *props_ &= ~kInitialAcyclic; 157 } 158 return true; 159 } 160 161 bool ForwardOrCrossArc(StateId s, const A &arc) { 162 StateId t = arc.nextstate; 163 if ((*dfnumber_)[t] < (*dfnumber_)[s] /* cross edge */ && 164 (*onstack_)[t] && (*dfnumber_)[t] < (*lowlink_)[s]) 165 (*lowlink_)[s] = (*dfnumber_)[t]; 166 if ((*coaccess_)[t]) 167 (*coaccess_)[s] = true; 168 return true; 169 } 170 171 void FinishState(StateId s, StateId p, const A *); 172 173 void FinishVisit() { 174 // Numbers SCC's in topological order when acyclic. 175 if (scc_) 176 for (StateId i = 0; i < scc_->size(); ++i) 177 (*scc_)[i] = nscc_ - 1 - (*scc_)[i]; 178 if (coaccess_internal_) 179 delete coaccess_; 180 delete dfnumber_; 181 delete lowlink_; 182 delete onstack_; 183 delete scc_stack_; 184 } 185 186 private: 187 vector<StateId> *scc_; // State's scc number 188 vector<bool> *access_; // State's accessibility 189 vector<bool> *coaccess_; // State's coaccessibility 190 uint64 *props_; 191 const Fst<A> *fst_; 192 StateId start_; 193 StateId nstates_; // State count 194 StateId nscc_; // SCC count 195 bool coaccess_internal_; 196 vector<StateId> *dfnumber_; // state discovery times 197 vector<StateId> *lowlink_; // lowlink[s] == dfnumber[s] => SCC root 198 vector<bool> *onstack_; // is a state on the SCC stack 199 vector<StateId> *scc_stack_; // SCC stack (w/ random access) 200 }; 201 202 template <class A> inline 203 void SccVisitor<A>::InitVisit(const Fst<A> &fst) { 204 if (scc_) 205 scc_->clear(); 206 if (access_) 207 access_->clear(); 208 if (coaccess_) { 209 coaccess_->clear(); 210 coaccess_internal_ = false; 211 } else { 212 coaccess_ = new vector<bool>; 213 coaccess_internal_ = true; 214 } 215 *props_ |= kAcyclic | kInitialAcyclic | kAccessible | kCoAccessible; 216 *props_ &= ~(kCyclic | kInitialCyclic | kNotAccessible | kNotCoAccessible); 217 fst_ = &fst; 218 start_ = fst.Start(); 219 nstates_ = 0; 220 nscc_ = 0; 221 dfnumber_ = new vector<StateId>; 222 lowlink_ = new vector<StateId>; 223 onstack_ = new vector<bool>; 224 scc_stack_ = new vector<StateId>; 225 } 226 227 template <class A> inline 228 bool SccVisitor<A>::InitState(StateId s, StateId root) { 229 scc_stack_->push_back(s); 230 while (dfnumber_->size() <= s) { 231 if (scc_) 232 scc_->push_back(-1); 233 if (access_) 234 access_->push_back(false); 235 coaccess_->push_back(false); 236 dfnumber_->push_back(-1); 237 lowlink_->push_back(-1); 238 onstack_->push_back(false); 239 } 240 (*dfnumber_)[s] = nstates_; 241 (*lowlink_)[s] = nstates_; 242 (*onstack_)[s] = true; 243 if (root == start_) { 244 if (access_) 245 (*access_)[s] = true; 246 } else { 247 if (access_) 248 (*access_)[s] = false; 249 *props_ |= kNotAccessible; 250 *props_ &= ~kAccessible; 251 } 252 ++nstates_; 253 return true; 254 } 255 256 template <class A> inline 257 void SccVisitor<A>::FinishState(StateId s, StateId p, const A *) { 258 if (fst_->Final(s) != Weight::Zero()) 259 (*coaccess_)[s] = true; 260 if ((*dfnumber_)[s] == (*lowlink_)[s]) { // root of new SCC 261 bool scc_coaccess = false; 262 size_t i = scc_stack_->size(); 263 StateId t; 264 do { 265 t = (*scc_stack_)[--i]; 266 if ((*coaccess_)[t]) 267 scc_coaccess = true; 268 } while (s != t); 269 do { 270 t = scc_stack_->back(); 271 if (scc_) 272 (*scc_)[t] = nscc_; 273 if (scc_coaccess) 274 (*coaccess_)[t] = true; 275 (*onstack_)[t] = false; 276 scc_stack_->pop_back(); 277 } while (s != t); 278 if (!scc_coaccess) { 279 *props_ |= kNotCoAccessible; 280 *props_ &= ~kCoAccessible; 281 } 282 ++nscc_; 283 } 284 if (p != kNoStateId) { 285 if ((*coaccess_)[s]) 286 (*coaccess_)[p] = true; 287 if ((*lowlink_)[s] < (*lowlink_)[p]) 288 (*lowlink_)[p] = (*lowlink_)[s]; 289 } 290 } 291 292 293 // Trims an FST, removing states and arcs that are not on successful 294 // paths. This version modifies its input. 295 // 296 // Complexity: 297 // - Time: O(V + E) 298 // - Space: O(V + E) 299 // where V = # of states and E = # of arcs. 300 template<class Arc> 301 void Connect(MutableFst<Arc> *fst) { 302 typedef typename Arc::StateId StateId; 303 304 vector<bool> access; 305 vector<bool> coaccess; 306 uint64 props = 0; 307 SccVisitor<Arc> scc_visitor(0, &access, &coaccess, &props); 308 DfsVisit(*fst, &scc_visitor); 309 vector<StateId> dstates; 310 for (StateId s = 0; s < access.size(); ++s) 311 if (!access[s] || !coaccess[s]) 312 dstates.push_back(s); 313 fst->DeleteStates(dstates); 314 fst->SetProperties(kAccessible | kCoAccessible, kAccessible | kCoAccessible); 315 } 316 317 } // namespace fst 318 319 #endif // FST_LIB_CONNECT_H__ 320