1 // rmepsilon.h 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Author: allauzen (at) cs.nyu.edu (Cyril Allauzen) 16 // 17 // \file 18 // Functions and classes that implemement epsilon-removal. 19 20 #ifndef FST_LIB_RMEPSILON_H__ 21 #define FST_LIB_RMEPSILON_H__ 22 23 #include <ext/hash_map> 24 using __gnu_cxx::hash_map; 25 #include <ext/slist> 26 using __gnu_cxx::slist; 27 28 #include "fst/lib/arcfilter.h" 29 #include "fst/lib/cache.h" 30 #include "fst/lib/connect.h" 31 #include "fst/lib/factor-weight.h" 32 #include "fst/lib/invert.h" 33 #include "fst/lib/map.h" 34 #include "fst/lib/queue.h" 35 #include "fst/lib/shortest-distance.h" 36 #include "fst/lib/topsort.h" 37 38 namespace fst { 39 40 template <class Arc, class Queue> 41 struct RmEpsilonOptions 42 : public ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc> > { 43 typedef typename Arc::StateId StateId; 44 45 bool connect; // Connect output 46 47 RmEpsilonOptions(Queue *q, float d = kDelta, bool c = true) 48 : ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc> >( 49 q, EpsilonArcFilter<Arc>(), kNoStateId, d), connect(c) {} 50 51 }; 52 53 54 // Computation state of the epsilon-removal algorithm. 55 template <class Arc, class Queue> 56 class RmEpsilonState { 57 public: 58 typedef typename Arc::Label Label; 59 typedef typename Arc::StateId StateId; 60 typedef typename Arc::Weight Weight; 61 62 RmEpsilonState(const Fst<Arc> &fst, 63 vector<Weight> *distance, 64 const RmEpsilonOptions<Arc, Queue> &opts) 65 : fst_(fst), distance_(distance), sd_state_(fst_, distance, opts, true) { 66 } 67 68 // Compute arcs and final weight for state 's' 69 void Expand(StateId s); 70 71 // Returns arcs of expanded state. 72 vector<Arc> &Arcs() { return arcs_; } 73 74 // Returns final weight of expanded state. 75 const Weight &Final() const { return final_; } 76 77 private: 78 struct Element { 79 Label ilabel; 80 Label olabel; 81 StateId nextstate; 82 83 Element() {} 84 85 Element(Label i, Label o, StateId s) 86 : ilabel(i), olabel(o), nextstate(s) {} 87 }; 88 89 class ElementKey { 90 public: 91 size_t operator()(const Element& e) const { 92 return static_cast<size_t>(e.nextstate); 93 return static_cast<size_t>(e.nextstate + 94 e.ilabel * kPrime0 + 95 e.olabel * kPrime1); 96 } 97 98 private: 99 static const int kPrime0 = 7853; 100 static const int kPrime1 = 7867; 101 }; 102 103 class ElementEqual { 104 public: 105 bool operator()(const Element &e1, const Element &e2) const { 106 return (e1.ilabel == e2.ilabel) && (e1.olabel == e2.olabel) 107 && (e1.nextstate == e2.nextstate); 108 } 109 }; 110 111 private: 112 typedef hash_map<Element, pair<StateId, ssize_t>, 113 ElementKey, ElementEqual> ElementMap; 114 115 const Fst<Arc> &fst_; 116 // Distance from state being expanded in epsilon-closure. 117 vector<Weight> *distance_; 118 // Shortest distance algorithm computation state. 119 ShortestDistanceState<Arc, Queue, EpsilonArcFilter<Arc> > sd_state_; 120 // Maps an element 'e' to a pair 'p' corresponding to a position 121 // in the arcs vector of the state being expanded. 'e' corresponds 122 // to the position 'p.second' in the 'arcs_' vector if 'p.first' is 123 // equal to the state being expanded. 124 ElementMap element_map_; 125 EpsilonArcFilter<Arc> eps_filter_; 126 stack<StateId> eps_queue_; // Queue used to visit the epsilon-closure 127 vector<bool> visited_; // '[i] = true' if state 'i' has been visited 128 slist<StateId> visited_states_; // List of visited states 129 vector<Arc> arcs_; // Arcs of state being expanded 130 Weight final_; // Final weight of state being expanded 131 132 void operator=(const RmEpsilonState); // Disallow 133 }; 134 135 136 template <class Arc, class Queue> 137 void RmEpsilonState<Arc,Queue>::Expand(typename Arc::StateId source) { 138 sd_state_.ShortestDistance(source); 139 eps_queue_.push(source); 140 final_ = Weight::Zero(); 141 arcs_.clear(); 142 143 while (!eps_queue_.empty()) { 144 StateId state = eps_queue_.top(); 145 eps_queue_.pop(); 146 147 while ((StateId)visited_.size() <= state) visited_.push_back(false); 148 visited_[state] = true; 149 visited_states_.push_front(state); 150 151 for (ArcIterator< Fst<Arc> > ait(fst_, state); 152 !ait.Done(); 153 ait.Next()) { 154 Arc arc = ait.Value(); 155 arc.weight = Times((*distance_)[state], arc.weight); 156 157 if (eps_filter_(arc)) { 158 while ((StateId)visited_.size() <= arc.nextstate) 159 visited_.push_back(false); 160 if (!visited_[arc.nextstate]) 161 eps_queue_.push(arc.nextstate); 162 } else { 163 Element element(arc.ilabel, arc.olabel, arc.nextstate); 164 typename ElementMap::iterator it = element_map_.find(element); 165 if (it == element_map_.end()) { 166 element_map_.insert( 167 pair<Element, pair<StateId, ssize_t> > 168 (element, pair<StateId, ssize_t>(source, arcs_.size()))); 169 arcs_.push_back(arc); 170 } else { 171 if (((*it).second).first == source) { 172 Weight &w = arcs_[((*it).second).second].weight; 173 w = Plus(w, arc.weight); 174 } else { 175 ((*it).second).first = source; 176 ((*it).second).second = arcs_.size(); 177 arcs_.push_back(arc); 178 } 179 } 180 } 181 } 182 final_ = Plus(final_, Times((*distance_)[state], fst_.Final(state))); 183 } 184 185 while (!visited_states_.empty()) { 186 visited_[visited_states_.front()] = false; 187 visited_states_.pop_front(); 188 } 189 } 190 191 192 // Removes epsilon-transitions (when both the input and output label 193 // are an epsilon) from a transducer. The result will be an equivalent 194 // FST that has no such epsilon transitions. This version modifies 195 // its input. It allows fine control via the options argument; see 196 // below for a simpler interface. 197 // 198 // The vector 'distance' will be used to hold the shortest distances 199 // during the epsilon-closure computation. The state queue discipline 200 // and convergence delta are taken in the options argument. 201 template <class Arc, class Queue> 202 void RmEpsilon(MutableFst<Arc> *fst, 203 vector<typename Arc::Weight> *distance, 204 const RmEpsilonOptions<Arc, Queue> &opts) { 205 typedef typename Arc::StateId StateId; 206 typedef typename Arc::Weight Weight; 207 typedef typename Arc::Label Label; 208 209 // States sorted in topological order when (acyclic) or generic 210 // topological order (cyclic). 211 vector<StateId> states; 212 213 if (fst->Properties(kTopSorted, false) & kTopSorted) { 214 for (StateId i = 0; i < (StateId)fst->NumStates(); i++) 215 states.push_back(i); 216 } else if (fst->Properties(kAcyclic, false) & kAcyclic) { 217 vector<StateId> order; 218 bool acyclic; 219 TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic); 220 DfsVisit(*fst, &top_order_visitor, EpsilonArcFilter<Arc>()); 221 if (!acyclic) 222 LOG(FATAL) << "RmEpsilon: not acyclic though property bit is set"; 223 states.resize(order.size()); 224 for (StateId i = 0; i < (StateId)order.size(); i++) 225 states[order[i]] = i; 226 } else { 227 uint64 props; 228 vector<StateId> scc; 229 SccVisitor<Arc> scc_visitor(&scc, 0, 0, &props); 230 DfsVisit(*fst, &scc_visitor, EpsilonArcFilter<Arc>()); 231 vector<StateId> first(scc.size(), kNoStateId); 232 vector<StateId> next(scc.size(), kNoStateId); 233 for (StateId i = 0; i < (StateId)scc.size(); i++) { 234 if (first[scc[i]] != kNoStateId) 235 next[i] = first[scc[i]]; 236 first[scc[i]] = i; 237 } 238 for (StateId i = 0; i < (StateId)first.size(); i++) 239 for (StateId j = first[i]; j != kNoStateId; j = next[j]) 240 states.push_back(j); 241 } 242 243 RmEpsilonState<Arc, Queue> 244 rmeps_state(*fst, distance, opts); 245 246 while (!states.empty()) { 247 StateId state = states.back(); 248 states.pop_back(); 249 rmeps_state.Expand(state); 250 fst->SetFinal(state, rmeps_state.Final()); 251 fst->DeleteArcs(state); 252 vector<Arc> &arcs = rmeps_state.Arcs(); 253 while (!arcs.empty()) { 254 fst->AddArc(state, arcs.back()); 255 arcs.pop_back(); 256 } 257 } 258 259 fst->SetProperties(RmEpsilonProperties( 260 fst->Properties(kFstProperties, false)), 261 kFstProperties); 262 263 if (opts.connect) 264 Connect(fst); 265 } 266 267 268 // Removes epsilon-transitions (when both the input and output label 269 // are an epsilon) from a transducer. The result will be an equivalent 270 // FST that has no such epsilon transitions. This version modifies its 271 // input. It has a simplified interface; see above for a version that 272 // allows finer control. 273 // 274 // Complexity: 275 // - Time: 276 // - Unweighted: O(V2 + V E) 277 // - Acyclic: O(V2 + V E) 278 // - Tropical semiring: O(V2 log V + V E) 279 // - General: exponential 280 // - Space: O(V E) 281 // where V = # of states visited, E = # of arcs. 282 // 283 // References: 284 // - Mehryar Mohri. Generic Epsilon-Removal and Input 285 // Epsilon-Normalization Algorithms for Weighted Transducers, 286 // "International Journal of Computer Science", 13(1):129-143 (2002). 287 template <class Arc> 288 void RmEpsilon(MutableFst<Arc> *fst, bool connect = true) { 289 typedef typename Arc::StateId StateId; 290 typedef typename Arc::Weight Weight; 291 typedef typename Arc::Label Label; 292 293 vector<Weight> distance; 294 AutoQueue<StateId> state_queue(*fst, &distance, EpsilonArcFilter<Arc>()); 295 RmEpsilonOptions<Arc, AutoQueue<StateId> > 296 opts(&state_queue, kDelta, connect); 297 298 RmEpsilon(fst, &distance, opts); 299 } 300 301 302 struct RmEpsilonFstOptions : CacheOptions { 303 float delta; 304 305 RmEpsilonFstOptions(const CacheOptions &opts, float delta = kDelta) 306 : CacheOptions(opts), delta(delta) {} 307 308 explicit RmEpsilonFstOptions(float delta = kDelta) : delta(delta) {} 309 }; 310 311 312 // Implementation of delayed RmEpsilonFst. 313 template <class A> 314 class RmEpsilonFstImpl : public CacheImpl<A> { 315 public: 316 using FstImpl<A>::SetType; 317 using FstImpl<A>::SetProperties; 318 using FstImpl<A>::Properties; 319 using FstImpl<A>::SetInputSymbols; 320 using FstImpl<A>::SetOutputSymbols; 321 322 using CacheBaseImpl< CacheState<A> >::HasStart; 323 using CacheBaseImpl< CacheState<A> >::HasFinal; 324 using CacheBaseImpl< CacheState<A> >::HasArcs; 325 326 typedef typename A::Label Label; 327 typedef typename A::Weight Weight; 328 typedef typename A::StateId StateId; 329 typedef CacheState<A> State; 330 331 RmEpsilonFstImpl(const Fst<A>& fst, const RmEpsilonFstOptions &opts) 332 : CacheImpl<A>(opts), 333 fst_(fst.Copy()), 334 rmeps_state_( 335 *fst_, 336 &distance_, 337 RmEpsilonOptions<A, FifoQueue<StateId> >(&queue_, opts.delta, 338 false) 339 ) { 340 SetType("rmepsilon"); 341 uint64 props = fst.Properties(kFstProperties, false); 342 SetProperties(RmEpsilonProperties(props, true), kCopyProperties); 343 } 344 345 ~RmEpsilonFstImpl() { 346 delete fst_; 347 } 348 349 StateId Start() { 350 if (!HasStart()) { 351 SetStart(fst_->Start()); 352 } 353 return CacheImpl<A>::Start(); 354 } 355 356 Weight Final(StateId s) { 357 if (!HasFinal(s)) { 358 Expand(s); 359 } 360 return CacheImpl<A>::Final(s); 361 } 362 363 size_t NumArcs(StateId s) { 364 if (!HasArcs(s)) 365 Expand(s); 366 return CacheImpl<A>::NumArcs(s); 367 } 368 369 size_t NumInputEpsilons(StateId s) { 370 if (!HasArcs(s)) 371 Expand(s); 372 return CacheImpl<A>::NumInputEpsilons(s); 373 } 374 375 size_t NumOutputEpsilons(StateId s) { 376 if (!HasArcs(s)) 377 Expand(s); 378 return CacheImpl<A>::NumOutputEpsilons(s); 379 } 380 381 void InitArcIterator(StateId s, ArcIteratorData<A> *data) { 382 if (!HasArcs(s)) 383 Expand(s); 384 CacheImpl<A>::InitArcIterator(s, data); 385 } 386 387 void Expand(StateId s) { 388 rmeps_state_.Expand(s); 389 SetFinal(s, rmeps_state_.Final()); 390 vector<A> &arcs = rmeps_state_.Arcs(); 391 while (!arcs.empty()) { 392 AddArc(s, arcs.back()); 393 arcs.pop_back(); 394 } 395 SetArcs(s); 396 } 397 398 private: 399 const Fst<A> *fst_; 400 vector<Weight> distance_; 401 FifoQueue<StateId> queue_; 402 RmEpsilonState<A, FifoQueue<StateId> > rmeps_state_; 403 404 DISALLOW_EVIL_CONSTRUCTORS(RmEpsilonFstImpl); 405 }; 406 407 408 // Removes epsilon-transitions (when both the input and output label 409 // are an epsilon) from a transducer. The result will be an equivalent 410 // FST that has no such epsilon transitions. This version is a 411 // delayed Fst. 412 // 413 // Complexity: 414 // - Time: 415 // - Unweighted: O(v^2 + v e) 416 // - General: exponential 417 // - Space: O(v e) 418 // where v = # of states visited, e = # of arcs visited. Constant time 419 // to visit an input state or arc is assumed and exclusive of caching. 420 // 421 // References: 422 // - Mehryar Mohri. Generic Epsilon-Removal and Input 423 // Epsilon-Normalization Algorithms for Weighted Transducers, 424 // "International Journal of Computer Science", 13(1):129-143 (2002). 425 template <class A> 426 class RmEpsilonFst : public Fst<A> { 427 public: 428 friend class ArcIterator< RmEpsilonFst<A> >; 429 friend class CacheStateIterator< RmEpsilonFst<A> >; 430 friend class CacheArcIterator< RmEpsilonFst<A> >; 431 432 typedef A Arc; 433 typedef typename A::Weight Weight; 434 typedef typename A::StateId StateId; 435 typedef CacheState<A> State; 436 437 RmEpsilonFst(const Fst<A> &fst) 438 : impl_(new RmEpsilonFstImpl<A>(fst, RmEpsilonFstOptions())) {} 439 440 RmEpsilonFst(const Fst<A> &fst, const RmEpsilonFstOptions &opts) 441 : impl_(new RmEpsilonFstImpl<A>(fst, opts)) {} 442 443 explicit RmEpsilonFst(const RmEpsilonFst<A> &fst) : impl_(fst.impl_) { 444 impl_->IncrRefCount(); 445 } 446 447 virtual ~RmEpsilonFst() { if (!impl_->DecrRefCount()) delete impl_; } 448 449 virtual StateId Start() const { return impl_->Start(); } 450 451 virtual Weight Final(StateId s) const { return impl_->Final(s); } 452 453 virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); } 454 455 virtual size_t NumInputEpsilons(StateId s) const { 456 return impl_->NumInputEpsilons(s); 457 } 458 459 virtual size_t NumOutputEpsilons(StateId s) const { 460 return impl_->NumOutputEpsilons(s); 461 } 462 463 virtual uint64 Properties(uint64 mask, bool test) const { 464 if (test) { 465 uint64 known, test = TestProperties(*this, mask, &known); 466 impl_->SetProperties(test, known); 467 return test & mask; 468 } else { 469 return impl_->Properties(mask); 470 } 471 } 472 473 virtual const string& Type() const { return impl_->Type(); } 474 475 virtual RmEpsilonFst<A> *Copy() const { 476 return new RmEpsilonFst<A>(*this); 477 } 478 479 virtual const SymbolTable* InputSymbols() const { 480 return impl_->InputSymbols(); 481 } 482 483 virtual const SymbolTable* OutputSymbols() const { 484 return impl_->OutputSymbols(); 485 } 486 487 virtual inline void InitStateIterator(StateIteratorData<A> *data) const; 488 489 virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { 490 impl_->InitArcIterator(s, data); 491 } 492 493 protected: 494 RmEpsilonFstImpl<A> *Impl() { return impl_; } 495 496 private: 497 RmEpsilonFstImpl<A> *impl_; 498 499 void operator=(const RmEpsilonFst<A> &fst); // disallow 500 }; 501 502 503 // Specialization for RmEpsilonFst. 504 template<class A> 505 class StateIterator< RmEpsilonFst<A> > 506 : public CacheStateIterator< RmEpsilonFst<A> > { 507 public: 508 explicit StateIterator(const RmEpsilonFst<A> &fst) 509 : CacheStateIterator< RmEpsilonFst<A> >(fst) {} 510 }; 511 512 513 // Specialization for RmEpsilonFst. 514 template <class A> 515 class ArcIterator< RmEpsilonFst<A> > 516 : public CacheArcIterator< RmEpsilonFst<A> > { 517 public: 518 typedef typename A::StateId StateId; 519 520 ArcIterator(const RmEpsilonFst<A> &fst, StateId s) 521 : CacheArcIterator< RmEpsilonFst<A> >(fst, s) { 522 if (!fst.impl_->HasArcs(s)) 523 fst.impl_->Expand(s); 524 } 525 526 private: 527 DISALLOW_EVIL_CONSTRUCTORS(ArcIterator); 528 }; 529 530 531 template <class A> inline 532 void RmEpsilonFst<A>::InitStateIterator(StateIteratorData<A> *data) const { 533 data->base = new StateIterator< RmEpsilonFst<A> >(*this); 534 } 535 536 537 // Useful alias when using StdArc. 538 typedef RmEpsilonFst<StdArc> StdRmEpsilonFst; 539 540 } // namespace fst 541 542 #endif // FST_LIB_RMEPSILON_H__ 543