1 // rmepsilon.h 2 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Copyright 2005-2010 Google, Inc. 16 // Author: allauzen (at) google.com (Cyril Allauzen) 17 // 18 // \file 19 // Functions and classes that implemement epsilon-removal. 20 21 #ifndef FST_LIB_RMEPSILON_H__ 22 #define FST_LIB_RMEPSILON_H__ 23 24 #include <tr1/unordered_map> 25 using std::tr1::unordered_map; 26 using std::tr1::unordered_multimap; 27 #include <fst/slist.h> 28 #include <stack> 29 #include <string> 30 #include <utility> 31 using std::pair; using std::make_pair; 32 #include <vector> 33 using std::vector; 34 35 #include <fst/arcfilter.h> 36 #include <fst/cache.h> 37 #include <fst/connect.h> 38 #include <fst/factor-weight.h> 39 #include <fst/invert.h> 40 #include <fst/prune.h> 41 #include <fst/queue.h> 42 #include <fst/shortest-distance.h> 43 #include <fst/topsort.h> 44 45 46 namespace fst { 47 48 template <class Arc, class Queue> 49 class RmEpsilonOptions 50 : public ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc> > { 51 public: 52 typedef typename Arc::StateId StateId; 53 typedef typename Arc::Weight Weight; 54 55 bool connect; // Connect output 56 Weight weight_threshold; // Pruning weight threshold. 57 StateId state_threshold; // Pruning state threshold. 58 59 explicit RmEpsilonOptions(Queue *q, float d = kDelta, bool c = true, 60 Weight w = Weight::Zero(), 61 StateId n = kNoStateId) 62 : ShortestDistanceOptions< Arc, Queue, EpsilonArcFilter<Arc> >( 63 q, EpsilonArcFilter<Arc>(), kNoStateId, d), 64 connect(c), weight_threshold(w), state_threshold(n) {} 65 private: 66 RmEpsilonOptions(); // disallow 67 }; 68 69 // Computation state of the epsilon-removal algorithm. 70 template <class Arc, class Queue> 71 class RmEpsilonState { 72 public: 73 typedef typename Arc::Label Label; 74 typedef typename Arc::StateId StateId; 75 typedef typename Arc::Weight Weight; 76 77 RmEpsilonState(const Fst<Arc> &fst, 78 vector<Weight> *distance, 79 const RmEpsilonOptions<Arc, Queue> &opts) 80 : fst_(fst), distance_(distance), sd_state_(fst_, distance, opts, true), 81 expand_id_(0) {} 82 83 // Compute arcs and final weight for state 's' 84 void Expand(StateId s); 85 86 // Returns arcs of expanded state. 87 vector<Arc> &Arcs() { return arcs_; } 88 89 // Returns final weight of expanded state. 90 const Weight &Final() const { return final_; } 91 92 // Return true if an error has occured. 93 bool Error() const { return sd_state_.Error(); } 94 95 private: 96 static const size_t kPrime0 = 7853; 97 static const size_t kPrime1 = 7867; 98 99 struct Element { 100 Label ilabel; 101 Label olabel; 102 StateId nextstate; 103 104 Element() {} 105 106 Element(Label i, Label o, StateId s) 107 : ilabel(i), olabel(o), nextstate(s) {} 108 }; 109 110 class ElementKey { 111 public: 112 size_t operator()(const Element& e) const { 113 return static_cast<size_t>(e.nextstate + 114 e.ilabel * kPrime0 + 115 e.olabel * kPrime1); 116 } 117 118 private: 119 }; 120 121 class ElementEqual { 122 public: 123 bool operator()(const Element &e1, const Element &e2) const { 124 return (e1.ilabel == e2.ilabel) && (e1.olabel == e2.olabel) 125 && (e1.nextstate == e2.nextstate); 126 } 127 }; 128 129 typedef unordered_map<Element, pair<StateId, size_t>, 130 ElementKey, ElementEqual> ElementMap; 131 132 const Fst<Arc> &fst_; 133 // Distance from state being expanded in epsilon-closure. 134 vector<Weight> *distance_; 135 // Shortest distance algorithm computation state. 136 ShortestDistanceState<Arc, Queue, EpsilonArcFilter<Arc> > sd_state_; 137 // Maps an element 'e' to a pair 'p' corresponding to a position 138 // in the arcs vector of the state being expanded. 'e' corresponds 139 // to the position 'p.second' in the 'arcs_' vector if 'p.first' is 140 // equal to the state being expanded. 141 ElementMap element_map_; 142 EpsilonArcFilter<Arc> eps_filter_; 143 stack<StateId> eps_queue_; // Queue used to visit the epsilon-closure 144 vector<bool> visited_; // '[i] = true' if state 'i' has been visited 145 slist<StateId> visited_states_; // List of visited states 146 vector<Arc> arcs_; // Arcs of state being expanded 147 Weight final_; // Final weight of state being expanded 148 StateId expand_id_; // Unique ID for each call to Expand 149 150 DISALLOW_COPY_AND_ASSIGN(RmEpsilonState); 151 }; 152 153 template <class Arc, class Queue> 154 const size_t RmEpsilonState<Arc, Queue>::kPrime0; 155 template <class Arc, class Queue> 156 const size_t RmEpsilonState<Arc, Queue>::kPrime1; 157 158 159 template <class Arc, class Queue> 160 void RmEpsilonState<Arc,Queue>::Expand(typename Arc::StateId source) { 161 final_ = Weight::Zero(); 162 arcs_.clear(); 163 sd_state_.ShortestDistance(source); 164 if (sd_state_.Error()) 165 return; 166 eps_queue_.push(source); 167 168 while (!eps_queue_.empty()) { 169 StateId state = eps_queue_.top(); 170 eps_queue_.pop(); 171 172 while (visited_.size() <= state) visited_.push_back(false); 173 if (visited_[state]) continue; 174 visited_[state] = true; 175 visited_states_.push_front(state); 176 177 for (ArcIterator< Fst<Arc> > ait(fst_, state); 178 !ait.Done(); 179 ait.Next()) { 180 Arc arc = ait.Value(); 181 arc.weight = Times((*distance_)[state], arc.weight); 182 183 if (eps_filter_(arc)) { 184 while (visited_.size() <= arc.nextstate) 185 visited_.push_back(false); 186 if (!visited_[arc.nextstate]) 187 eps_queue_.push(arc.nextstate); 188 } else { 189 Element element(arc.ilabel, arc.olabel, arc.nextstate); 190 typename ElementMap::iterator it = element_map_.find(element); 191 if (it == element_map_.end()) { 192 element_map_.insert( 193 pair<Element, pair<StateId, size_t> > 194 (element, pair<StateId, size_t>(expand_id_, arcs_.size()))); 195 arcs_.push_back(arc); 196 } else { 197 if (((*it).second).first == expand_id_) { 198 Weight &w = arcs_[((*it).second).second].weight; 199 w = Plus(w, arc.weight); 200 } else { 201 ((*it).second).first = expand_id_; 202 ((*it).second).second = arcs_.size(); 203 arcs_.push_back(arc); 204 } 205 } 206 } 207 } 208 final_ = Plus(final_, Times((*distance_)[state], fst_.Final(state))); 209 } 210 211 while (!visited_states_.empty()) { 212 visited_[visited_states_.front()] = false; 213 visited_states_.pop_front(); 214 } 215 ++expand_id_; 216 } 217 218 // Removes epsilon-transitions (when both the input and output label 219 // are an epsilon) from a transducer. The result will be an equivalent 220 // FST that has no such epsilon transitions. This version modifies 221 // its input. It allows fine control via the options argument; see 222 // below for a simpler interface. 223 // 224 // The vector 'distance' will be used to hold the shortest distances 225 // during the epsilon-closure computation. The state queue discipline 226 // and convergence delta are taken in the options argument. 227 template <class Arc, class Queue> 228 void RmEpsilon(MutableFst<Arc> *fst, 229 vector<typename Arc::Weight> *distance, 230 const RmEpsilonOptions<Arc, Queue> &opts) { 231 typedef typename Arc::StateId StateId; 232 typedef typename Arc::Weight Weight; 233 typedef typename Arc::Label Label; 234 235 if (fst->Start() == kNoStateId) { 236 return; 237 } 238 239 // 'noneps_in[s]' will be set to true iff 's' admits a non-epsilon 240 // incoming transition or is the start state. 241 vector<bool> noneps_in(fst->NumStates(), false); 242 noneps_in[fst->Start()] = true; 243 for (StateId i = 0; i < fst->NumStates(); ++i) { 244 for (ArcIterator<Fst<Arc> > aiter(*fst, i); 245 !aiter.Done(); 246 aiter.Next()) { 247 if (aiter.Value().ilabel != 0 || aiter.Value().olabel != 0) 248 noneps_in[aiter.Value().nextstate] = true; 249 } 250 } 251 252 // States sorted in topological order when (acyclic) or generic 253 // topological order (cyclic). 254 vector<StateId> states; 255 states.reserve(fst->NumStates()); 256 257 if (fst->Properties(kTopSorted, false) & kTopSorted) { 258 for (StateId i = 0; i < fst->NumStates(); i++) 259 states.push_back(i); 260 } else if (fst->Properties(kAcyclic, false) & kAcyclic) { 261 vector<StateId> order; 262 bool acyclic; 263 TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic); 264 DfsVisit(*fst, &top_order_visitor, EpsilonArcFilter<Arc>()); 265 // Sanity check: should be acyclic if property bit is set. 266 if(!acyclic) { 267 FSTERROR() << "RmEpsilon: inconsistent acyclic property bit"; 268 fst->SetProperties(kError, kError); 269 return; 270 } 271 states.resize(order.size()); 272 for (StateId i = 0; i < order.size(); i++) 273 states[order[i]] = i; 274 } else { 275 uint64 props; 276 vector<StateId> scc; 277 SccVisitor<Arc> scc_visitor(&scc, 0, 0, &props); 278 DfsVisit(*fst, &scc_visitor, EpsilonArcFilter<Arc>()); 279 vector<StateId> first(scc.size(), kNoStateId); 280 vector<StateId> next(scc.size(), kNoStateId); 281 for (StateId i = 0; i < scc.size(); i++) { 282 if (first[scc[i]] != kNoStateId) 283 next[i] = first[scc[i]]; 284 first[scc[i]] = i; 285 } 286 for (StateId i = 0; i < first.size(); i++) 287 for (StateId j = first[i]; j != kNoStateId; j = next[j]) 288 states.push_back(j); 289 } 290 291 RmEpsilonState<Arc, Queue> 292 rmeps_state(*fst, distance, opts); 293 294 while (!states.empty()) { 295 StateId state = states.back(); 296 states.pop_back(); 297 if (!noneps_in[state]) 298 continue; 299 rmeps_state.Expand(state); 300 fst->SetFinal(state, rmeps_state.Final()); 301 fst->DeleteArcs(state); 302 vector<Arc> &arcs = rmeps_state.Arcs(); 303 fst->ReserveArcs(state, arcs.size()); 304 while (!arcs.empty()) { 305 fst->AddArc(state, arcs.back()); 306 arcs.pop_back(); 307 } 308 } 309 310 for (StateId s = 0; s < fst->NumStates(); ++s) { 311 if (!noneps_in[s]) 312 fst->DeleteArcs(s); 313 } 314 315 if(rmeps_state.Error()) 316 fst->SetProperties(kError, kError); 317 fst->SetProperties( 318 RmEpsilonProperties(fst->Properties(kFstProperties, false)), 319 kFstProperties); 320 321 if (opts.weight_threshold != Weight::Zero() || 322 opts.state_threshold != kNoStateId) 323 Prune(fst, opts.weight_threshold, opts.state_threshold); 324 if (opts.connect && (opts.weight_threshold == Weight::Zero() || 325 opts.state_threshold != kNoStateId)) 326 Connect(fst); 327 } 328 329 // Removes epsilon-transitions (when both the input and output label 330 // are an epsilon) from a transducer. The result will be an equivalent 331 // FST that has no such epsilon transitions. This version modifies its 332 // input. It has a simplified interface; see above for a version that 333 // allows finer control. 334 // 335 // Complexity: 336 // - Time: 337 // - Unweighted: O(V2 + V E) 338 // - Acyclic: O(V2 + V E) 339 // - Tropical semiring: O(V2 log V + V E) 340 // - General: exponential 341 // - Space: O(V E) 342 // where V = # of states visited, E = # of arcs. 343 // 344 // References: 345 // - Mehryar Mohri. Generic Epsilon-Removal and Input 346 // Epsilon-Normalization Algorithms for Weighted Transducers, 347 // "International Journal of Computer Science", 13(1):129-143 (2002). 348 template <class Arc> 349 void RmEpsilon(MutableFst<Arc> *fst, 350 bool connect = true, 351 typename Arc::Weight weight_threshold = Arc::Weight::Zero(), 352 typename Arc::StateId state_threshold = kNoStateId, 353 float delta = kDelta) { 354 typedef typename Arc::StateId StateId; 355 typedef typename Arc::Weight Weight; 356 typedef typename Arc::Label Label; 357 358 vector<Weight> distance; 359 AutoQueue<StateId> state_queue(*fst, &distance, EpsilonArcFilter<Arc>()); 360 RmEpsilonOptions<Arc, AutoQueue<StateId> > 361 opts(&state_queue, delta, connect, weight_threshold, state_threshold); 362 363 RmEpsilon(fst, &distance, opts); 364 } 365 366 367 struct RmEpsilonFstOptions : CacheOptions { 368 float delta; 369 370 RmEpsilonFstOptions(const CacheOptions &opts, float delta = kDelta) 371 : CacheOptions(opts), delta(delta) {} 372 373 explicit RmEpsilonFstOptions(float delta = kDelta) : delta(delta) {} 374 }; 375 376 377 // Implementation of delayed RmEpsilonFst. 378 template <class A> 379 class RmEpsilonFstImpl : public CacheImpl<A> { 380 public: 381 using FstImpl<A>::SetType; 382 using FstImpl<A>::SetProperties; 383 using FstImpl<A>::SetInputSymbols; 384 using FstImpl<A>::SetOutputSymbols; 385 386 using CacheBaseImpl< CacheState<A> >::PushArc; 387 using CacheBaseImpl< CacheState<A> >::HasArcs; 388 using CacheBaseImpl< CacheState<A> >::HasFinal; 389 using CacheBaseImpl< CacheState<A> >::HasStart; 390 using CacheBaseImpl< CacheState<A> >::SetArcs; 391 using CacheBaseImpl< CacheState<A> >::SetFinal; 392 using CacheBaseImpl< CacheState<A> >::SetStart; 393 394 typedef typename A::Label Label; 395 typedef typename A::Weight Weight; 396 typedef typename A::StateId StateId; 397 typedef CacheState<A> State; 398 399 RmEpsilonFstImpl(const Fst<A>& fst, const RmEpsilonFstOptions &opts) 400 : CacheImpl<A>(opts), 401 fst_(fst.Copy()), 402 delta_(opts.delta), 403 rmeps_state_( 404 *fst_, 405 &distance_, 406 RmEpsilonOptions<A, FifoQueue<StateId> >(&queue_, delta_, false)) { 407 SetType("rmepsilon"); 408 uint64 props = fst.Properties(kFstProperties, false); 409 SetProperties(RmEpsilonProperties(props, true), kCopyProperties); 410 SetInputSymbols(fst.InputSymbols()); 411 SetOutputSymbols(fst.OutputSymbols()); 412 } 413 414 RmEpsilonFstImpl(const RmEpsilonFstImpl &impl) 415 : CacheImpl<A>(impl), 416 fst_(impl.fst_->Copy(true)), 417 delta_(impl.delta_), 418 rmeps_state_( 419 *fst_, 420 &distance_, 421 RmEpsilonOptions<A, FifoQueue<StateId> >(&queue_, delta_, false)) { 422 SetType("rmepsilon"); 423 SetProperties(impl.Properties(), kCopyProperties); 424 SetInputSymbols(impl.InputSymbols()); 425 SetOutputSymbols(impl.OutputSymbols()); 426 } 427 428 ~RmEpsilonFstImpl() { 429 delete fst_; 430 } 431 432 StateId Start() { 433 if (!HasStart()) { 434 SetStart(fst_->Start()); 435 } 436 return CacheImpl<A>::Start(); 437 } 438 439 Weight Final(StateId s) { 440 if (!HasFinal(s)) { 441 Expand(s); 442 } 443 return CacheImpl<A>::Final(s); 444 } 445 446 size_t NumArcs(StateId s) { 447 if (!HasArcs(s)) 448 Expand(s); 449 return CacheImpl<A>::NumArcs(s); 450 } 451 452 size_t NumInputEpsilons(StateId s) { 453 if (!HasArcs(s)) 454 Expand(s); 455 return CacheImpl<A>::NumInputEpsilons(s); 456 } 457 458 size_t NumOutputEpsilons(StateId s) { 459 if (!HasArcs(s)) 460 Expand(s); 461 return CacheImpl<A>::NumOutputEpsilons(s); 462 } 463 464 uint64 Properties() const { return Properties(kFstProperties); } 465 466 // Set error if found; return FST impl properties. 467 uint64 Properties(uint64 mask) const { 468 if ((mask & kError) && 469 (fst_->Properties(kError, false) || rmeps_state_.Error())) 470 SetProperties(kError, kError); 471 return FstImpl<A>::Properties(mask); 472 } 473 474 void InitArcIterator(StateId s, ArcIteratorData<A> *data) { 475 if (!HasArcs(s)) 476 Expand(s); 477 CacheImpl<A>::InitArcIterator(s, data); 478 } 479 480 void Expand(StateId s) { 481 rmeps_state_.Expand(s); 482 SetFinal(s, rmeps_state_.Final()); 483 vector<A> &arcs = rmeps_state_.Arcs(); 484 while (!arcs.empty()) { 485 PushArc(s, arcs.back()); 486 arcs.pop_back(); 487 } 488 SetArcs(s); 489 } 490 491 private: 492 const Fst<A> *fst_; 493 float delta_; 494 vector<Weight> distance_; 495 FifoQueue<StateId> queue_; 496 RmEpsilonState<A, FifoQueue<StateId> > rmeps_state_; 497 498 void operator=(const RmEpsilonFstImpl<A> &); // disallow 499 }; 500 501 502 // Removes epsilon-transitions (when both the input and output label 503 // are an epsilon) from a transducer. The result will be an equivalent 504 // FST that has no such epsilon transitions. This version is a 505 // delayed Fst. 506 // 507 // Complexity: 508 // - Time: 509 // - Unweighted: O(v^2 + v e) 510 // - General: exponential 511 // - Space: O(v e) 512 // where v = # of states visited, e = # of arcs visited. Constant time 513 // to visit an input state or arc is assumed and exclusive of caching. 514 // 515 // References: 516 // - Mehryar Mohri. Generic Epsilon-Removal and Input 517 // Epsilon-Normalization Algorithms for Weighted Transducers, 518 // "International Journal of Computer Science", 13(1):129-143 (2002). 519 // 520 // This class attaches interface to implementation and handles 521 // reference counting, delegating most methods to ImplToFst. 522 template <class A> 523 class RmEpsilonFst : public ImplToFst< RmEpsilonFstImpl<A> > { 524 public: 525 friend class ArcIterator< RmEpsilonFst<A> >; 526 friend class StateIterator< RmEpsilonFst<A> >; 527 528 typedef A Arc; 529 typedef typename A::StateId StateId; 530 typedef CacheState<A> State; 531 typedef RmEpsilonFstImpl<A> Impl; 532 533 RmEpsilonFst(const Fst<A> &fst) 534 : ImplToFst<Impl>(new Impl(fst, RmEpsilonFstOptions())) {} 535 536 RmEpsilonFst(const Fst<A> &fst, const RmEpsilonFstOptions &opts) 537 : ImplToFst<Impl>(new Impl(fst, opts)) {} 538 539 // See Fst<>::Copy() for doc. 540 RmEpsilonFst(const RmEpsilonFst<A> &fst, bool safe = false) 541 : ImplToFst<Impl>(fst, safe) {} 542 543 // Get a copy of this RmEpsilonFst. See Fst<>::Copy() for further doc. 544 virtual RmEpsilonFst<A> *Copy(bool safe = false) const { 545 return new RmEpsilonFst<A>(*this, safe); 546 } 547 548 virtual inline void InitStateIterator(StateIteratorData<A> *data) const; 549 550 virtual void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const { 551 GetImpl()->InitArcIterator(s, data); 552 } 553 554 private: 555 // Makes visible to friends. 556 Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } 557 558 void operator=(const RmEpsilonFst<A> &fst); // disallow 559 }; 560 561 // Specialization for RmEpsilonFst. 562 template<class A> 563 class StateIterator< RmEpsilonFst<A> > 564 : public CacheStateIterator< RmEpsilonFst<A> > { 565 public: 566 explicit StateIterator(const RmEpsilonFst<A> &fst) 567 : CacheStateIterator< RmEpsilonFst<A> >(fst, fst.GetImpl()) {} 568 }; 569 570 571 // Specialization for RmEpsilonFst. 572 template <class A> 573 class ArcIterator< RmEpsilonFst<A> > 574 : public CacheArcIterator< RmEpsilonFst<A> > { 575 public: 576 typedef typename A::StateId StateId; 577 578 ArcIterator(const RmEpsilonFst<A> &fst, StateId s) 579 : CacheArcIterator< RmEpsilonFst<A> >(fst.GetImpl(), s) { 580 if (!fst.GetImpl()->HasArcs(s)) 581 fst.GetImpl()->Expand(s); 582 } 583 584 private: 585 DISALLOW_COPY_AND_ASSIGN(ArcIterator); 586 }; 587 588 589 template <class A> inline 590 void RmEpsilonFst<A>::InitStateIterator(StateIteratorData<A> *data) const { 591 data->base = new StateIterator< RmEpsilonFst<A> >(*this); 592 } 593 594 595 // Useful alias when using StdArc. 596 typedef RmEpsilonFst<StdArc> StdRmEpsilonFst; 597 598 } // namespace fst 599 600 #endif // FST_LIB_RMEPSILON_H__ 601