1 // rmepsilon.h 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Author: allauzen (at) cs.nyu.edu (Cyril Allauzen) 16 // 17 // \file 18 // Functions and classes that implemement epsilon-removal. 19 20 #ifndef FST_LIB_RMEPSILON_H__ 21 #define FST_LIB_RMEPSILON_H__ 22 23 #include <unordered_map> 24 #include <forward_list> 25 26 #include "fst/lib/arcfilter.h" 27 #include "fst/lib/cache.h" 28 #include "fst/lib/connect.h" 29 #include "fst/lib/factor-weight.h" 30 #include "fst/lib/invert.h" 31 #include "fst/lib/map.h" 32 #include "fst/lib/queue.h" 33 #include "fst/lib/shortest-distance.h" 34 #include "fst/lib/topsort.h" 35 36 namespace fst { 37 38 template <class Arc, class Queue> 39 struct RmEpsilonOptions 40 : public ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc> > { 41 typedef typename Arc::StateId StateId; 42 43 bool connect; // Connect output 44 45 RmEpsilonOptions(Queue *q, float d = kDelta, bool c = true) 46 : ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc> >( 47 q, EpsilonArcFilter<Arc>(), kNoStateId, d), connect(c) {} 48 49 }; 50 51 52 // Computation state of the epsilon-removal algorithm. 53 template <class Arc, class Queue> 54 class RmEpsilonState { 55 public: 56 typedef typename Arc::Label Label; 57 typedef typename Arc::StateId StateId; 58 typedef typename Arc::Weight Weight; 59 60 RmEpsilonState(const Fst<Arc> &fst, 61 vector<Weight> *distance, 62 const RmEpsilonOptions<Arc, Queue> &opts) 63 : fst_(fst), distance_(distance), sd_state_(fst_, distance, opts, true) { 64 } 65 66 // Compute arcs and final weight for state 's' 67 void Expand(StateId s); 68 69 // Returns arcs of expanded state. 70 vector<Arc> &Arcs() { return arcs_; } 71 72 // Returns final weight of expanded state. 73 const Weight &Final() const { return final_; } 74 75 private: 76 struct Element { 77 Label ilabel; 78 Label olabel; 79 StateId nextstate; 80 81 Element() {} 82 83 Element(Label i, Label o, StateId s) 84 : ilabel(i), olabel(o), nextstate(s) {} 85 }; 86 87 class ElementKey { 88 public: 89 size_t operator()(const Element& e) const { 90 return static_cast<size_t>(e.nextstate); 91 return static_cast<size_t>(e.nextstate + 92 e.ilabel * kPrime0 + 93 e.olabel * kPrime1); 94 } 95 96 private: 97 static const int kPrime0 = 7853; 98 static const int kPrime1 = 7867; 99 }; 100 101 class ElementEqual { 102 public: 103 bool operator()(const Element &e1, const Element &e2) const { 104 return (e1.ilabel == e2.ilabel) && (e1.olabel == e2.olabel) 105 && (e1.nextstate == e2.nextstate); 106 } 107 }; 108 109 private: 110 typedef std::unordered_map<Element, pair<StateId, ssize_t>, 111 ElementKey, ElementEqual> ElementMap; 112 113 const Fst<Arc> &fst_; 114 // Distance from state being expanded in epsilon-closure. 115 vector<Weight> *distance_; 116 // Shortest distance algorithm computation state. 117 ShortestDistanceState<Arc, Queue, EpsilonArcFilter<Arc> > sd_state_; 118 // Maps an element 'e' to a pair 'p' corresponding to a position 119 // in the arcs vector of the state being expanded. 'e' corresponds 120 // to the position 'p.second' in the 'arcs_' vector if 'p.first' is 121 // equal to the state being expanded. 122 ElementMap element_map_; 123 EpsilonArcFilter<Arc> eps_filter_; 124 stack<StateId> eps_queue_; // Queue used to visit the epsilon-closure 125 vector<bool> visited_; // '[i] = true' if state 'i' has been visited 126 std::forward_list<StateId> visited_states_; // List of visited states 127 vector<Arc> arcs_; // Arcs of state being expanded 128 Weight final_; // Final weight of state being expanded 129 130 void operator=(const RmEpsilonState); // Disallow 131 }; 132 133 134 template <class Arc, class Queue> 135 void RmEpsilonState<Arc,Queue>::Expand(typename Arc::StateId source) { 136 sd_state_.ShortestDistance(source); 137 eps_queue_.push(source); 138 final_ = Weight::Zero(); 139 arcs_.clear(); 140 141 while (!eps_queue_.empty()) { 142 StateId state = eps_queue_.top(); 143 eps_queue_.pop(); 144 145 while ((StateId)visited_.size() <= state) visited_.push_back(false); 146 visited_[state] = true; 147 visited_states_.push_front(state); 148 149 for (ArcIterator< Fst<Arc> > ait(fst_, state); 150 !ait.Done(); 151 ait.Next()) { 152 Arc arc = ait.Value(); 153 arc.weight = Times((*distance_)[state], arc.weight); 154 155 if (eps_filter_(arc)) { 156 while ((StateId)visited_.size() <= arc.nextstate) 157 visited_.push_back(false); 158 if (!visited_[arc.nextstate]) 159 eps_queue_.push(arc.nextstate); 160 } else { 161 Element element(arc.ilabel, arc.olabel, arc.nextstate); 162 typename ElementMap::iterator it = element_map_.find(element); 163 if (it == element_map_.end()) { 164 element_map_.insert( 165 pair<Element, pair<StateId, ssize_t> > 166 (element, pair<StateId, ssize_t>(source, arcs_.size()))); 167 arcs_.push_back(arc); 168 } else { 169 if (((*it).second).first == source) { 170 Weight &w = arcs_[((*it).second).second].weight; 171 w = Plus(w, arc.weight); 172 } else { 173 ((*it).second).first = source; 174 ((*it).second).second = arcs_.size(); 175 arcs_.push_back(arc); 176 } 177 } 178 } 179 } 180 final_ = Plus(final_, Times((*distance_)[state], fst_.Final(state))); 181 } 182 183 while (!visited_states_.empty()) { 184 visited_[visited_states_.front()] = false; 185 visited_states_.pop_front(); 186 } 187 } 188 189 190 // Removes epsilon-transitions (when both the input and output label 191 // are an epsilon) from a transducer. The result will be an equivalent 192 // FST that has no such epsilon transitions. This version modifies 193 // its input. It allows fine control via the options argument; see 194 // below for a simpler interface. 195 // 196 // The vector 'distance' will be used to hold the shortest distances 197 // during the epsilon-closure computation. The state queue discipline 198 // and convergence delta are taken in the options argument. 199 template <class Arc, class Queue> 200 void RmEpsilon(MutableFst<Arc> *fst, 201 vector<typename Arc::Weight> *distance, 202 const RmEpsilonOptions<Arc, Queue> &opts) { 203 typedef typename Arc::StateId StateId; 204 typedef typename Arc::Weight Weight; 205 typedef typename Arc::Label Label; 206 207 // States sorted in topological order when (acyclic) or generic 208 // topological order (cyclic). 209 vector<StateId> states; 210 211 if (fst->Properties(kTopSorted, false) & kTopSorted) { 212 for (StateId i = 0; i < (StateId)fst->NumStates(); i++) 213 states.push_back(i); 214 } else if (fst->Properties(kAcyclic, false) & kAcyclic) { 215 vector<StateId> order; 216 bool acyclic; 217 TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic); 218 DfsVisit(*fst, &top_order_visitor, EpsilonArcFilter<Arc>()); 219 if (!acyclic) 220 LOG(FATAL) << "RmEpsilon: not acyclic though property bit is set"; 221 states.resize(order.size()); 222 for (StateId i = 0; i < (StateId)order.size(); i++) 223 states[order[i]] = i; 224 } else { 225 uint64 props; 226 vector<StateId> scc; 227 SccVisitor<Arc> scc_visitor(&scc, 0, 0, &props); 228 DfsVisit(*fst, &scc_visitor, EpsilonArcFilter<Arc>()); 229 vector<StateId> first(scc.size(), kNoStateId); 230 vector<StateId> next(scc.size(), kNoStateId); 231 for (StateId i = 0; i < (StateId)scc.size(); i++) { 232 if (first[scc[i]] != kNoStateId) 233 next[i] = first[scc[i]]; 234 first[scc[i]] = i; 235 } 236 for (StateId i = 0; i < (StateId)first.size(); i++) 237 for (StateId j = first[i]; j != kNoStateId; j = next[j]) 238 states.push_back(j); 239 } 240 241 RmEpsilonState<Arc, Queue> 242 rmeps_state(*fst, distance, opts); 243 244 while (!states.empty()) { 245 StateId state = states.back(); 246 states.pop_back(); 247 rmeps_state.Expand(state); 248 fst->SetFinal(state, rmeps_state.Final()); 249 fst->DeleteArcs(state); 250 vector<Arc> &arcs = rmeps_state.Arcs(); 251 while (!arcs.empty()) { 252 fst->AddArc(state, arcs.back()); 253 arcs.pop_back(); 254 } 255 } 256 257 fst->SetProperties(RmEpsilonProperties( 258 fst->Properties(kFstProperties, false)), 259 kFstProperties); 260 261 if (opts.connect) 262 Connect(fst); 263 } 264 265 266 // Removes epsilon-transitions (when both the input and output label 267 // are an epsilon) from a transducer. The result will be an equivalent 268 // FST that has no such epsilon transitions. This version modifies its 269 // input. It has a simplified interface; see above for a version that 270 // allows finer control. 271 // 272 // Complexity: 273 // - Time: 274 // - Unweighted: O(V2 + V E) 275 // - Acyclic: O(V2 + V E) 276 // - Tropical semiring: O(V2 log V + V E) 277 // - General: exponential 278 // - Space: O(V E) 279 // where V = # of states visited, E = # of arcs. 280 // 281 // References: 282 // - Mehryar Mohri. Generic Epsilon-Removal and Input 283 // Epsilon-Normalization Algorithms for Weighted Transducers, 284 // "International Journal of Computer Science", 13(1):129-143 (2002). 285 template <class Arc> 286 void RmEpsilon(MutableFst<Arc> *fst, bool connect = true) { 287 typedef typename Arc::StateId StateId; 288 typedef typename Arc::Weight Weight; 289 typedef typename Arc::Label Label; 290 291 vector<Weight> distance; 292 AutoQueue<StateId> state_queue(*fst, &distance, EpsilonArcFilter<Arc>()); 293 RmEpsilonOptions<Arc, AutoQueue<StateId> > 294 opts(&state_queue, kDelta, connect); 295 296 RmEpsilon(fst, &distance, opts); 297 } 298 299 300 struct RmEpsilonFstOptions : CacheOptions { 301 float delta; 302 303 RmEpsilonFstOptions(const CacheOptions &opts, float delta = kDelta) 304 : CacheOptions(opts), delta(delta) {} 305 306 explicit RmEpsilonFstOptions(float delta = kDelta) : delta(delta) {} 307 }; 308 309 310 // Implementation of delayed RmEpsilonFst. 311 template <class A> 312 class RmEpsilonFstImpl : public CacheImpl<A> { 313 public: 314 using FstImpl<A>::SetType; 315 using FstImpl<A>::SetProperties; 316 using FstImpl<A>::Properties; 317 using FstImpl<A>::SetInputSymbols; 318 using FstImpl<A>::SetOutputSymbols; 319 320 using CacheBaseImpl< CacheState<A> >::HasStart; 321 using CacheBaseImpl< CacheState<A> >::HasFinal; 322 using CacheBaseImpl< CacheState<A> >::HasArcs; 323 324 typedef typename A::Label Label; 325 typedef typename A::Weight Weight; 326 typedef typename A::StateId StateId; 327 typedef CacheState<A> State; 328 329 RmEpsilonFstImpl(const Fst<A>& fst, const RmEpsilonFstOptions &opts) 330 : CacheImpl<A>(opts), 331 fst_(fst.Copy()), 332 rmeps_state_( 333 *fst_, 334 &distance_, 335 RmEpsilonOptions<A, FifoQueue<StateId> >(&queue_, opts.delta, 336 false) 337 ) { 338 SetType("rmepsilon"); 339 uint64 props = fst.Properties(kFstProperties, false); 340 SetProperties(RmEpsilonProperties(props, true), kCopyProperties); 341 } 342 343 ~RmEpsilonFstImpl() { 344 delete fst_; 345 } 346 347 StateId Start() { 348 if (!HasStart()) { 349 SetStart(fst_->Start()); 350 } 351 return CacheImpl<A>::Start(); 352 } 353 354 Weight Final(StateId s) { 355 if (!HasFinal(s)) { 356 Expand(s); 357 } 358 return CacheImpl<A>::Final(s); 359 } 360 361 size_t NumArcs(StateId s) { 362 if (!HasArcs(s)) 363 Expand(s); 364 return CacheImpl<A>::NumArcs(s); 365 } 366 367 size_t NumInputEpsilons(StateId s) { 368 if (!HasArcs(s)) 369 Expand(s); 370 return CacheImpl<A>::NumInputEpsilons(s); 371 } 372 373 size_t NumOutputEpsilons(StateId s) { 374 if (!HasArcs(s)) 375 Expand(s); 376 return CacheImpl<A>::NumOutputEpsilons(s); 377 } 378 379 void InitArcIterator(StateId s, ArcIteratorData<A> *data) { 380 if (!HasArcs(s)) 381 Expand(s); 382 CacheImpl<A>::InitArcIterator(s, data); 383 } 384 385 void Expand(StateId s) { 386 rmeps_state_.Expand(s); 387 SetFinal(s, rmeps_state_.Final()); 388 vector<A> &arcs = rmeps_state_.Arcs(); 389 while (!arcs.empty()) { 390 AddArc(s, arcs.back()); 391 arcs.pop_back(); 392 } 393 SetArcs(s); 394 } 395 396 private: 397 const Fst<A> *fst_; 398 vector<Weight> distance_; 399 FifoQueue<StateId> queue_; 400 RmEpsilonState<A, FifoQueue<StateId> > rmeps_state_; 401 402 DISALLOW_EVIL_CONSTRUCTORS(RmEpsilonFstImpl); 403 }; 404 405 406 // Removes epsilon-transitions (when both the input and output label 407 // are an epsilon) from a transducer. The result will be an equivalent 408 // FST that has no such epsilon transitions. This version is a 409 // delayed Fst. 410 // 411 // Complexity: 412 // - Time: 413 // - Unweighted: O(v^2 + v e) 414 // - General: exponential 415 // - Space: O(v e) 416 // where v = # of states visited, e = # of arcs visited. Constant time 417 // to visit an input state or arc is assumed and exclusive of caching. 418 // 419 // References: 420 // - Mehryar Mohri. Generic Epsilon-Removal and Input 421 // Epsilon-Normalization Algorithms for Weighted Transducers, 422 // "International Journal of Computer Science", 13(1):129-143 (2002). 423 template <class A> 424 class RmEpsilonFst : public Fst<A> { 425 public: 426 friend class ArcIterator< RmEpsilonFst<A> >; 427 friend class CacheStateIterator< RmEpsilonFst<A> >; 428 friend class CacheArcIterator< RmEpsilonFst<A> >; 429 430 typedef A Arc; 431 typedef typename A::Weight Weight; 432 typedef typename A::StateId StateId; 433 typedef CacheState<A> State; 434 435 RmEpsilonFst(const Fst<A> &fst) 436 : impl_(new RmEpsilonFstImpl<A>(fst, RmEpsilonFstOptions())) {} 437 438 RmEpsilonFst(const Fst<A> &fst, const RmEpsilonFstOptions &opts) 439 : impl_(new RmEpsilonFstImpl<A>(fst, opts)) {} 440 441 explicit RmEpsilonFst(const RmEpsilonFst<A> &fst) : impl_(fst.impl_) { 442 impl_->IncrRefCount(); 443 } 444 445 virtual ~RmEpsilonFst() { if (!impl_->DecrRefCount()) delete impl_; } 446 447 virtual StateId Start() const { return impl_->Start(); } 448 449 virtual Weight Final(StateId s) const { return impl_->Final(s); } 450 451 virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); } 452 453 virtual size_t NumInputEpsilons(StateId s) const { 454 return impl_->NumInputEpsilons(s); 455 } 456 457 virtual size_t NumOutputEpsilons(StateId s) const { 458 return impl_->NumOutputEpsilons(s); 459 } 460 461 virtual uint64 Properties(uint64 mask, bool test) const { 462 if (test) { 463 uint64 known, test = TestProperties(*this, mask, &known); 464 impl_->SetProperties(test, known); 465 return test & mask; 466 } else { 467 return impl_->Properties(mask); 468 } 469 } 470 471 virtual const string& Type() const { return impl_->Type(); } 472 473 virtual RmEpsilonFst<A> *Copy() const { 474 return new RmEpsilonFst<A>(*this); 475 } 476 477 virtual const SymbolTable* InputSymbols() const { 478 return impl_->InputSymbols(); 479 } 480 481 virtual const SymbolTable* OutputSymbols() const { 482 return impl_->OutputSymbols(); 483 } 484 485 virtual inline void InitStateIterator(StateIteratorData<A> *data) const; 486 487 virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { 488 impl_->InitArcIterator(s, data); 489 } 490 491 protected: 492 RmEpsilonFstImpl<A> *Impl() { return impl_; } 493 494 private: 495 RmEpsilonFstImpl<A> *impl_; 496 497 void operator=(const RmEpsilonFst<A> &fst); // disallow 498 }; 499 500 501 // Specialization for RmEpsilonFst. 502 template<class A> 503 class StateIterator< RmEpsilonFst<A> > 504 : public CacheStateIterator< RmEpsilonFst<A> > { 505 public: 506 explicit StateIterator(const RmEpsilonFst<A> &fst) 507 : CacheStateIterator< RmEpsilonFst<A> >(fst) {} 508 }; 509 510 511 // Specialization for RmEpsilonFst. 512 template <class A> 513 class ArcIterator< RmEpsilonFst<A> > 514 : public CacheArcIterator< RmEpsilonFst<A> > { 515 public: 516 typedef typename A::StateId StateId; 517 518 ArcIterator(const RmEpsilonFst<A> &fst, StateId s) 519 : CacheArcIterator< RmEpsilonFst<A> >(fst, s) { 520 if (!fst.impl_->HasArcs(s)) 521 fst.impl_->Expand(s); 522 } 523 524 private: 525 DISALLOW_EVIL_CONSTRUCTORS(ArcIterator); 526 }; 527 528 529 template <class A> inline 530 void RmEpsilonFst<A>::InitStateIterator(StateIteratorData<A> *data) const { 531 data->base = new StateIterator< RmEpsilonFst<A> >(*this); 532 } 533 534 535 // Useful alias when using StdArc. 536 typedef RmEpsilonFst<StdArc> StdRmEpsilonFst; 537 538 } // namespace fst 539 540 #endif // FST_LIB_RMEPSILON_H__ 541