1 // expand.h 2 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Copyright 2005-2010 Google, Inc. 16 // Author: riley (at) google.com (Michael Riley) 17 // 18 // \file 19 // Expand a PDT to an FST. 20 21 #ifndef FST_EXTENSIONS_PDT_EXPAND_H__ 22 #define FST_EXTENSIONS_PDT_EXPAND_H__ 23 24 #include <vector> 25 using std::vector; 26 27 #include <fst/extensions/pdt/pdt.h> 28 #include <fst/extensions/pdt/paren.h> 29 #include <fst/extensions/pdt/shortest-path.h> 30 #include <fst/extensions/pdt/reverse.h> 31 #include <fst/cache.h> 32 #include <fst/mutable-fst.h> 33 #include <fst/queue.h> 34 #include <fst/state-table.h> 35 #include <fst/test-properties.h> 36 37 namespace fst { 38 39 template <class Arc> 40 struct ExpandFstOptions : public CacheOptions { 41 bool keep_parentheses; 42 PdtStack<typename Arc::StateId, typename Arc::Label> *stack; 43 PdtStateTable<typename Arc::StateId, typename Arc::StateId> *state_table; 44 45 ExpandFstOptions( 46 const CacheOptions &opts = CacheOptions(), 47 bool kp = false, 48 PdtStack<typename Arc::StateId, typename Arc::Label> *s = 0, 49 PdtStateTable<typename Arc::StateId, typename Arc::StateId> *st = 0) 50 : CacheOptions(opts), keep_parentheses(kp), stack(s), state_table(st) {} 51 }; 52 53 // Properties for an expanded PDT. 54 inline uint64 ExpandProperties(uint64 inprops) { 55 return inprops & (kAcceptor | kAcyclic | kInitialAcyclic | kUnweighted); 56 } 57 58 59 // Implementation class for ExpandFst 60 template <class A> 61 class ExpandFstImpl 62 : public CacheImpl<A> { 63 public: 64 using FstImpl<A>::SetType; 65 using FstImpl<A>::SetProperties; 66 using FstImpl<A>::Properties; 67 using FstImpl<A>::SetInputSymbols; 68 using FstImpl<A>::SetOutputSymbols; 69 70 using CacheBaseImpl< CacheState<A> >::PushArc; 71 using CacheBaseImpl< CacheState<A> >::HasArcs; 72 using CacheBaseImpl< CacheState<A> >::HasFinal; 73 using CacheBaseImpl< CacheState<A> >::HasStart; 74 using CacheBaseImpl< CacheState<A> >::SetArcs; 75 using CacheBaseImpl< CacheState<A> >::SetFinal; 76 using CacheBaseImpl< CacheState<A> >::SetStart; 77 78 typedef A Arc; 79 typedef typename A::Label Label; 80 typedef typename A::Weight Weight; 81 typedef typename A::StateId StateId; 82 typedef StateId StackId; 83 typedef PdtStateTuple<StateId, StackId> StateTuple; 84 85 ExpandFstImpl(const Fst<A> &fst, 86 const vector<pair<typename Arc::Label, 87 typename Arc::Label> > &parens, 88 const ExpandFstOptions<A> &opts) 89 : CacheImpl<A>(opts), fst_(fst.Copy()), 90 stack_(opts.stack ? opts.stack: new PdtStack<StateId, Label>(parens)), 91 state_table_(opts.state_table ? opts.state_table : 92 new PdtStateTable<StateId, StackId>()), 93 own_stack_(opts.stack == 0), own_state_table_(opts.state_table == 0), 94 keep_parentheses_(opts.keep_parentheses) { 95 SetType("expand"); 96 97 uint64 props = fst.Properties(kFstProperties, false); 98 SetProperties(ExpandProperties(props), kCopyProperties); 99 100 SetInputSymbols(fst.InputSymbols()); 101 SetOutputSymbols(fst.OutputSymbols()); 102 } 103 104 ExpandFstImpl(const ExpandFstImpl &impl) 105 : CacheImpl<A>(impl), 106 fst_(impl.fst_->Copy(true)), 107 stack_(new PdtStack<StateId, Label>(*impl.stack_)), 108 state_table_(new PdtStateTable<StateId, StackId>()), 109 own_stack_(true), own_state_table_(true), 110 keep_parentheses_(impl.keep_parentheses_) { 111 SetType("expand"); 112 SetProperties(impl.Properties(), kCopyProperties); 113 SetInputSymbols(impl.InputSymbols()); 114 SetOutputSymbols(impl.OutputSymbols()); 115 } 116 117 ~ExpandFstImpl() { 118 delete fst_; 119 if (own_stack_) 120 delete stack_; 121 if (own_state_table_) 122 delete state_table_; 123 } 124 125 StateId Start() { 126 if (!HasStart()) { 127 StateId s = fst_->Start(); 128 if (s == kNoStateId) 129 return kNoStateId; 130 StateTuple tuple(s, 0); 131 StateId start = state_table_->FindState(tuple); 132 SetStart(start); 133 } 134 return CacheImpl<A>::Start(); 135 } 136 137 Weight Final(StateId s) { 138 if (!HasFinal(s)) { 139 const StateTuple &tuple = state_table_->Tuple(s); 140 Weight w = fst_->Final(tuple.state_id); 141 if (w != Weight::Zero() && tuple.stack_id == 0) 142 SetFinal(s, w); 143 else 144 SetFinal(s, Weight::Zero()); 145 } 146 return CacheImpl<A>::Final(s); 147 } 148 149 size_t NumArcs(StateId s) { 150 if (!HasArcs(s)) { 151 ExpandState(s); 152 } 153 return CacheImpl<A>::NumArcs(s); 154 } 155 156 size_t NumInputEpsilons(StateId s) { 157 if (!HasArcs(s)) 158 ExpandState(s); 159 return CacheImpl<A>::NumInputEpsilons(s); 160 } 161 162 size_t NumOutputEpsilons(StateId s) { 163 if (!HasArcs(s)) 164 ExpandState(s); 165 return CacheImpl<A>::NumOutputEpsilons(s); 166 } 167 168 void InitArcIterator(StateId s, ArcIteratorData<A> *data) { 169 if (!HasArcs(s)) 170 ExpandState(s); 171 CacheImpl<A>::InitArcIterator(s, data); 172 } 173 174 // Computes the outgoing transitions from a state, creating new destination 175 // states as needed. 176 void ExpandState(StateId s) { 177 StateTuple tuple = state_table_->Tuple(s); 178 for (ArcIterator< Fst<A> > aiter(*fst_, tuple.state_id); 179 !aiter.Done(); aiter.Next()) { 180 Arc arc = aiter.Value(); 181 StackId stack_id = stack_->Find(tuple.stack_id, arc.ilabel); 182 if (stack_id == -1) { 183 // Non-matching close parenthesis 184 continue; 185 } else if ((stack_id != tuple.stack_id) && !keep_parentheses_) { 186 // Stack push/pop 187 arc.ilabel = arc.olabel = 0; 188 } 189 190 StateTuple ntuple(arc.nextstate, stack_id); 191 arc.nextstate = state_table_->FindState(ntuple); 192 PushArc(s, arc); 193 } 194 SetArcs(s); 195 } 196 197 const PdtStack<StackId, Label> &GetStack() const { return *stack_; } 198 199 const PdtStateTable<StateId, StackId> &GetStateTable() const { 200 return *state_table_; 201 } 202 203 private: 204 const Fst<A> *fst_; 205 206 PdtStack<StackId, Label> *stack_; 207 PdtStateTable<StateId, StackId> *state_table_; 208 bool own_stack_; 209 bool own_state_table_; 210 bool keep_parentheses_; 211 212 void operator=(const ExpandFstImpl<A> &); // disallow 213 }; 214 215 // Expands a pushdown transducer (PDT) encoded as an FST into an FST. 216 // This version is a delayed Fst. In the PDT, some transitions are 217 // labeled with open or close parentheses. To be interpreted as a PDT, 218 // the parens must balance on a path. The open-close parenthesis label 219 // pairs are passed in 'parens'. The expansion enforces the 220 // parenthesis constraints. The PDT must be expandable as an FST. 221 // 222 // This class attaches interface to implementation and handles 223 // reference counting, delegating most methods to ImplToFst. 224 template <class A> 225 class ExpandFst : public ImplToFst< ExpandFstImpl<A> > { 226 public: 227 friend class ArcIterator< ExpandFst<A> >; 228 friend class StateIterator< ExpandFst<A> >; 229 230 typedef A Arc; 231 typedef typename A::Label Label; 232 typedef typename A::Weight Weight; 233 typedef typename A::StateId StateId; 234 typedef StateId StackId; 235 typedef CacheState<A> State; 236 typedef ExpandFstImpl<A> Impl; 237 238 ExpandFst(const Fst<A> &fst, 239 const vector<pair<typename Arc::Label, 240 typename Arc::Label> > &parens) 241 : ImplToFst<Impl>(new Impl(fst, parens, ExpandFstOptions<A>())) {} 242 243 ExpandFst(const Fst<A> &fst, 244 const vector<pair<typename Arc::Label, 245 typename Arc::Label> > &parens, 246 const ExpandFstOptions<A> &opts) 247 : ImplToFst<Impl>(new Impl(fst, parens, opts)) {} 248 249 // See Fst<>::Copy() for doc. 250 ExpandFst(const ExpandFst<A> &fst, bool safe = false) 251 : ImplToFst<Impl>(fst, safe) {} 252 253 // Get a copy of this ExpandFst. See Fst<>::Copy() for further doc. 254 virtual ExpandFst<A> *Copy(bool safe = false) const { 255 return new ExpandFst<A>(*this, safe); 256 } 257 258 virtual inline void InitStateIterator(StateIteratorData<A> *data) const; 259 260 virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { 261 GetImpl()->InitArcIterator(s, data); 262 } 263 264 const PdtStack<StackId, Label> &GetStack() const { 265 return GetImpl()->GetStack(); 266 } 267 268 const PdtStateTable<StateId, StackId> &GetStateTable() const { 269 return GetImpl()->GetStateTable(); 270 } 271 272 private: 273 // Makes visible to friends. 274 Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } 275 276 void operator=(const ExpandFst<A> &fst); // Disallow 277 }; 278 279 280 // Specialization for ExpandFst. 281 template<class A> 282 class StateIterator< ExpandFst<A> > 283 : public CacheStateIterator< ExpandFst<A> > { 284 public: 285 explicit StateIterator(const ExpandFst<A> &fst) 286 : CacheStateIterator< ExpandFst<A> >(fst, fst.GetImpl()) {} 287 }; 288 289 290 // Specialization for ExpandFst. 291 template <class A> 292 class ArcIterator< ExpandFst<A> > 293 : public CacheArcIterator< ExpandFst<A> > { 294 public: 295 typedef typename A::StateId StateId; 296 297 ArcIterator(const ExpandFst<A> &fst, StateId s) 298 : CacheArcIterator< ExpandFst<A> >(fst.GetImpl(), s) { 299 if (!fst.GetImpl()->HasArcs(s)) 300 fst.GetImpl()->ExpandState(s); 301 } 302 303 private: 304 DISALLOW_COPY_AND_ASSIGN(ArcIterator); 305 }; 306 307 308 template <class A> inline 309 void ExpandFst<A>::InitStateIterator(StateIteratorData<A> *data) const 310 { 311 data->base = new StateIterator< ExpandFst<A> >(*this); 312 } 313 314 // 315 // PrunedExpand Class 316 // 317 318 // Prunes the delayed expansion of a pushdown transducer (PDT) encoded 319 // as an FST into an FST. In the PDT, some transitions are labeled 320 // with open or close parentheses. To be interpreted as a PDT, the 321 // parens must balance on a path. The open-close parenthesis label 322 // pairs are passed in 'parens'. The expansion enforces the 323 // parenthesis constraints. 324 // 325 // The algorithm works by visiting the delayed ExpandFst using a 326 // shortest-stack first queue discipline and relies on the 327 // shortest-distance information computed using a reverse 328 // shortest-path call to perform the pruning. 329 // 330 // The algorithm maintains the same state ordering between the ExpandFst 331 // being visited 'efst_' and the result of pruning written into the 332 // MutableFst 'ofst_' to improve readability of the code. 333 // 334 template <class A> 335 class PrunedExpand { 336 public: 337 typedef A Arc; 338 typedef typename A::Label Label; 339 typedef typename A::StateId StateId; 340 typedef typename A::Weight Weight; 341 typedef StateId StackId; 342 typedef PdtStack<StackId, Label> Stack; 343 typedef PdtStateTable<StateId, StackId> StateTable; 344 typedef typename PdtBalanceData<Arc>::SetIterator SetIterator; 345 346 // Constructor taking as input a PDT specified by 'ifst' and 'parens'. 347 // 'keep_parentheses' specifies whether parentheses are replaced by 348 // epsilons or not during the expansion. 'opts' is the cache options 349 // used to instantiate the underlying ExpandFst. 350 PrunedExpand(const Fst<A> &ifst, 351 const vector<pair<Label, Label> > &parens, 352 bool keep_parentheses = false, 353 const CacheOptions &opts = CacheOptions()) 354 : ifst_(ifst.Copy()), 355 keep_parentheses_(keep_parentheses), 356 stack_(parens), 357 efst_(ifst, parens, 358 ExpandFstOptions<Arc>(opts, true, &stack_, &state_table_)), 359 queue_(state_table_, stack_, stack_length_, distance_, fdistance_) { 360 Reverse(*ifst_, parens, &rfst_); 361 VectorFst<Arc> path; 362 reverse_shortest_path_ = new SP( 363 rfst_, parens, 364 PdtShortestPathOptions<A, FifoQueue<StateId> >(true, false)); 365 reverse_shortest_path_->ShortestPath(&path); 366 balance_data_ = reverse_shortest_path_->GetBalanceData()->Reverse( 367 rfst_.NumStates(), 10, -1); 368 369 InitCloseParenMultimap(parens); 370 } 371 372 ~PrunedExpand() { 373 delete ifst_; 374 delete reverse_shortest_path_; 375 delete balance_data_; 376 } 377 378 // Expands and prunes with weight threshold 'threshold' the input PDT. 379 // Writes the result in 'ofst'. 380 void Expand(MutableFst<A> *ofst, const Weight &threshold); 381 382 private: 383 static const uint8 kEnqueued; 384 static const uint8 kExpanded; 385 static const uint8 kSourceState; 386 387 // Comparison functor used by the queue: 388 // 1. states corresponding to shortest stack first, 389 // 2. among stacks of the same length, reverse lexicographic order is used, 390 // 3. among states with the same stack, shortest-first order is used. 391 class StackCompare { 392 public: 393 StackCompare(const StateTable &st, 394 const Stack &s, const vector<StackId> &sl, 395 const vector<Weight> &d, const vector<Weight> &fd) 396 : state_table_(st), stack_(s), stack_length_(sl), 397 distance_(d), fdistance_(fd) {} 398 399 bool operator()(StateId s1, StateId s2) const { 400 StackId si1 = state_table_.Tuple(s1).stack_id; 401 StackId si2 = state_table_.Tuple(s2).stack_id; 402 if (stack_length_[si1] < stack_length_[si2]) 403 return true; 404 if (stack_length_[si1] > stack_length_[si2]) 405 return false; 406 // If stack id equal, use A* 407 if (si1 == si2) { 408 Weight w1 = (s1 < distance_.size()) && (s1 < fdistance_.size()) ? 409 Times(distance_[s1], fdistance_[s1]) : Weight::Zero(); 410 Weight w2 = (s2 < distance_.size()) && (s2 < fdistance_.size()) ? 411 Times(distance_[s2], fdistance_[s2]) : Weight::Zero(); 412 return less_(w1, w2); 413 } 414 // If lenghts are equal, use reverse lexico. 415 for (; si1 != si2; si1 = stack_.Pop(si1), si2 = stack_.Pop(si2)) { 416 if (stack_.Top(si1) < stack_.Top(si2)) return true; 417 if (stack_.Top(si1) > stack_.Top(si2)) return false; 418 } 419 return false; 420 } 421 422 private: 423 const StateTable &state_table_; 424 const Stack &stack_; 425 const vector<StackId> &stack_length_; 426 const vector<Weight> &distance_; 427 const vector<Weight> &fdistance_; 428 NaturalLess<Weight> less_; 429 }; 430 431 class ShortestStackFirstQueue 432 : public ShortestFirstQueue<StateId, StackCompare> { 433 public: 434 ShortestStackFirstQueue( 435 const PdtStateTable<StateId, StackId> &st, 436 const Stack &s, 437 const vector<StackId> &sl, 438 const vector<Weight> &d, const vector<Weight> &fd) 439 : ShortestFirstQueue<StateId, StackCompare>( 440 StackCompare(st, s, sl, d, fd)) {} 441 }; 442 443 444 void InitCloseParenMultimap(const vector<pair<Label, Label> > &parens); 445 Weight DistanceToDest(StateId state, StateId source) const; 446 uint8 Flags(StateId s) const; 447 void SetFlags(StateId s, uint8 flags, uint8 mask); 448 Weight Distance(StateId s) const; 449 void SetDistance(StateId s, Weight w); 450 Weight FinalDistance(StateId s) const; 451 void SetFinalDistance(StateId s, Weight w); 452 StateId SourceState(StateId s) const; 453 void SetSourceState(StateId s, StateId p); 454 void AddStateAndEnqueue(StateId s); 455 void Relax(StateId s, const A &arc, Weight w); 456 bool PruneArc(StateId s, const A &arc); 457 void ProcStart(); 458 void ProcFinal(StateId s); 459 bool ProcNonParen(StateId s, const A &arc, bool add_arc); 460 bool ProcOpenParen(StateId s, const A &arc, StackId si, StackId nsi); 461 bool ProcCloseParen(StateId s, const A &arc); 462 void ProcDestStates(StateId s, StackId si); 463 464 Fst<A> *ifst_; // Input PDT 465 VectorFst<Arc> rfst_; // Reversed PDT 466 bool keep_parentheses_; // Keep parentheses in ofst? 467 StateTable state_table_; // State table for efst_ 468 Stack stack_; // Stack trie 469 ExpandFst<Arc> efst_; // Expanded PDT 470 vector<StackId> stack_length_; // Length of stack for given stack id 471 vector<Weight> distance_; // Distance from initial state in efst_/ofst 472 vector<Weight> fdistance_; // Distance to final states in efst_/ofst 473 ShortestStackFirstQueue queue_; // Queue used to visit efst_ 474 vector<uint8> flags_; // Status flags for states in efst_/ofst 475 vector<StateId> sources_; // PDT source state for each expanded state 476 477 typedef PdtShortestPath<Arc, FifoQueue<StateId> > SP; 478 typedef typename SP::CloseParenMultimap ParenMultimap; 479 SP *reverse_shortest_path_; // Shortest path for rfst_ 480 PdtBalanceData<Arc> *balance_data_; // Not owned by shortest_path_ 481 ParenMultimap close_paren_multimap_; // Maps open paren arcs to 482 // balancing close paren arcs. 483 484 MutableFst<Arc> *ofst_; // Output fst 485 Weight limit_; // Weight limit 486 487 typedef unordered_map<StateId, Weight> DestMap; 488 DestMap dest_map_; 489 StackId current_stack_id_; 490 // 'current_stack_id_' is the stack id of the states currently at the top 491 // of queue, i.e., the states currently being popped and processed. 492 // 'dest_map_' maps a state 's' in 'ifst_' that is the source 493 // of a close parentheses matching the top of 'current_stack_id_; to 494 // the shortest-distance from '(s, current_stack_id_)' to the final 495 // states in 'efst_'. 496 ssize_t current_paren_id_; // Paren id at top of current stack 497 ssize_t cached_stack_id_; 498 StateId cached_source_; 499 slist<pair<StateId, Weight> > cached_dest_list_; 500 // 'cached_dest_list_' contains the set of pair of destination 501 // states and weight to final states for source state 502 // 'cached_source_' and paren id 'cached_paren_id': the set of 503 // source state of a close parenthesis with paren id 504 // 'cached_paren_id' balancing an incoming open parenthesis with 505 // paren id 'cached_paren_id' in state 'cached_source_'. 506 507 NaturalLess<Weight> less_; 508 }; 509 510 template <class A> const uint8 PrunedExpand<A>::kEnqueued = 0x01; 511 template <class A> const uint8 PrunedExpand<A>::kExpanded = 0x02; 512 template <class A> const uint8 PrunedExpand<A>::kSourceState = 0x04; 513 514 515 // Initializes close paren multimap, mapping pairs (s,paren_id) to 516 // all the arcs out of s labeled with close parenthese for paren_id. 517 template <class A> 518 void PrunedExpand<A>::InitCloseParenMultimap( 519 const vector<pair<Label, Label> > &parens) { 520 unordered_map<Label, Label> paren_id_map; 521 for (Label i = 0; i < parens.size(); ++i) { 522 const pair<Label, Label> &p = parens[i]; 523 paren_id_map[p.first] = i; 524 paren_id_map[p.second] = i; 525 } 526 527 for (StateIterator<Fst<Arc> > siter(*ifst_); !siter.Done(); siter.Next()) { 528 StateId s = siter.Value(); 529 for (ArcIterator<Fst<Arc> > aiter(*ifst_, s); 530 !aiter.Done(); aiter.Next()) { 531 const Arc &arc = aiter.Value(); 532 typename unordered_map<Label, Label>::const_iterator pit 533 = paren_id_map.find(arc.ilabel); 534 if (pit == paren_id_map.end()) continue; 535 if (arc.ilabel == parens[pit->second].second) { // Close paren 536 ParenState<Arc> paren_state(pit->second, s); 537 close_paren_multimap_.insert(make_pair(paren_state, arc)); 538 } 539 } 540 } 541 } 542 543 544 // Returns the weight of the shortest balanced path from 'source' to 'dest' 545 // in 'ifst_', 'dest' must be the source state of a close paren arc. 546 template <class A> 547 typename A::Weight PrunedExpand<A>::DistanceToDest(StateId source, 548 StateId dest) const { 549 typename SP::SearchState s(source + 1, dest + 1); 550 VLOG(2) << "D(" << source << ", " << dest << ") =" 551 << reverse_shortest_path_->GetShortestPathData().Distance(s); 552 return reverse_shortest_path_->GetShortestPathData().Distance(s); 553 } 554 555 // Returns the flags for state 's' in 'ofst_'. 556 template <class A> 557 uint8 PrunedExpand<A>::Flags(StateId s) const { 558 return s < flags_.size() ? flags_[s] : 0; 559 } 560 561 // Modifies the flags for state 's' in 'ofst_'. 562 template <class A> 563 void PrunedExpand<A>::SetFlags(StateId s, uint8 flags, uint8 mask) { 564 while (flags_.size() <= s) flags_.push_back(0); 565 flags_[s] &= ~mask; 566 flags_[s] |= flags & mask; 567 } 568 569 570 // Returns the shortest distance from the initial state to 's' in 'ofst_'. 571 template <class A> 572 typename A::Weight PrunedExpand<A>::Distance(StateId s) const { 573 return s < distance_.size() ? distance_[s] : Weight::Zero(); 574 } 575 576 // Sets the shortest distance from the initial state to 's' in 'ofst_' to 'w'. 577 template <class A> 578 void PrunedExpand<A>::SetDistance(StateId s, Weight w) { 579 while (distance_.size() <= s ) distance_.push_back(Weight::Zero()); 580 distance_[s] = w; 581 } 582 583 584 // Returns the shortest distance from 's' to the final states in 'ofst_'. 585 template <class A> 586 typename A::Weight PrunedExpand<A>::FinalDistance(StateId s) const { 587 return s < fdistance_.size() ? fdistance_[s] : Weight::Zero(); 588 } 589 590 // Sets the shortest distance from 's' to the final states in 'ofst_' to 'w'. 591 template <class A> 592 void PrunedExpand<A>::SetFinalDistance(StateId s, Weight w) { 593 while (fdistance_.size() <= s) fdistance_.push_back(Weight::Zero()); 594 fdistance_[s] = w; 595 } 596 597 // Returns the PDT "source" state of state 's' in 'ofst_'. 598 template <class A> 599 typename A::StateId PrunedExpand<A>::SourceState(StateId s) const { 600 return s < sources_.size() ? sources_[s] : kNoStateId; 601 } 602 603 // Sets the PDT "source" state of state 's' in 'ofst_' to state 'p' in 'ifst_'. 604 template <class A> 605 void PrunedExpand<A>::SetSourceState(StateId s, StateId p) { 606 while (sources_.size() <= s) sources_.push_back(kNoStateId); 607 sources_[s] = p; 608 } 609 610 // Adds state 's' of 'efst_' to 'ofst_' and inserts it in the queue, 611 // modifying the flags for 's' accordingly. 612 template <class A> 613 void PrunedExpand<A>::AddStateAndEnqueue(StateId s) { 614 if (!(Flags(s) & (kEnqueued | kExpanded))) { 615 while (ofst_->NumStates() <= s) ofst_->AddState(); 616 queue_.Enqueue(s); 617 SetFlags(s, kEnqueued, kEnqueued); 618 } else if (Flags(s) & kEnqueued) { 619 queue_.Update(s); 620 } 621 // TODO(allauzen): Check everything is fine when kExpanded? 622 } 623 624 // Relaxes arc 'arc' out of state 's' in 'ofst_': 625 // * if the distance to 's' times the weight of 'arc' is smaller than 626 // the currently stored distance for 'arc.nextstate', 627 // updates 'Distance(arc.nextstate)' with new estimate; 628 // * if 'fd' is less than the currently stored distance from 'arc.nextstate' 629 // to the final state, updates with new estimate. 630 template <class A> 631 void PrunedExpand<A>::Relax(StateId s, const A &arc, Weight fd) { 632 Weight nd = Times(Distance(s), arc.weight); 633 if (less_(nd, Distance(arc.nextstate))) { 634 SetDistance(arc.nextstate, nd); 635 SetSourceState(arc.nextstate, SourceState(s)); 636 } 637 if (less_(fd, FinalDistance(arc.nextstate))) 638 SetFinalDistance(arc.nextstate, fd); 639 VLOG(2) << "Relax: " << s << ", d[s] = " << Distance(s) << ", to " 640 << arc.nextstate << ", d[ns] = " << Distance(arc.nextstate) 641 << ", nd = " << nd; 642 } 643 644 // Returns 'true' if the arc 'arc' out of state 's' in 'efst_' needs to 645 // be pruned. 646 template <class A> 647 bool PrunedExpand<A>::PruneArc(StateId s, const A &arc) { 648 VLOG(2) << "Prune ?"; 649 Weight fd = Weight::Zero(); 650 651 if ((cached_source_ != SourceState(s)) || 652 (cached_stack_id_ != current_stack_id_)) { 653 cached_source_ = SourceState(s); 654 cached_stack_id_ = current_stack_id_; 655 cached_dest_list_.clear(); 656 if (cached_source_ != ifst_->Start()) { 657 for (SetIterator set_iter = 658 balance_data_->Find(current_paren_id_, cached_source_); 659 !set_iter.Done(); set_iter.Next()) { 660 StateId dest = set_iter.Element(); 661 typename DestMap::const_iterator iter = dest_map_.find(dest); 662 cached_dest_list_.push_front(*iter); 663 } 664 } else { 665 // TODO(allauzen): queue discipline should prevent this never 666 // from happening; replace by a check. 667 cached_dest_list_.push_front( 668 make_pair(rfst_.Start() -1, Weight::One())); 669 } 670 } 671 672 for (typename slist<pair<StateId, Weight> >::const_iterator iter = 673 cached_dest_list_.begin(); 674 iter != cached_dest_list_.end(); 675 ++iter) { 676 fd = Plus(fd, 677 Times(DistanceToDest(state_table_.Tuple(arc.nextstate).state_id, 678 iter->first), 679 iter->second)); 680 } 681 Relax(s, arc, fd); 682 Weight w = Times(Distance(s), Times(arc.weight, fd)); 683 return less_(limit_, w); 684 } 685 686 // Adds start state of 'efst_' to 'ofst_', enqueues it and initializes 687 // the distance data structures. 688 template <class A> 689 void PrunedExpand<A>::ProcStart() { 690 StateId s = efst_.Start(); 691 AddStateAndEnqueue(s); 692 ofst_->SetStart(s); 693 SetSourceState(s, ifst_->Start()); 694 695 current_stack_id_ = 0; 696 current_paren_id_ = -1; 697 stack_length_.push_back(0); 698 dest_map_[rfst_.Start() - 1] = Weight::One(); // not needed 699 700 cached_source_ = ifst_->Start(); 701 cached_stack_id_ = 0; 702 cached_dest_list_.push_front( 703 make_pair(rfst_.Start() -1, Weight::One())); 704 705 PdtStateTuple<StateId, StackId> tuple(rfst_.Start() - 1, 0); 706 SetFinalDistance(state_table_.FindState(tuple), Weight::One()); 707 SetDistance(s, Weight::One()); 708 SetFinalDistance(s, DistanceToDest(ifst_->Start(), rfst_.Start() - 1)); 709 VLOG(2) << DistanceToDest(ifst_->Start(), rfst_.Start() - 1); 710 } 711 712 // Makes 's' final in 'ofst_' if shortest accepting path ending in 's' 713 // is below threshold. 714 template <class A> 715 void PrunedExpand<A>::ProcFinal(StateId s) { 716 Weight final = efst_.Final(s); 717 if ((final == Weight::Zero()) || less_(limit_, Times(Distance(s), final))) 718 return; 719 ofst_->SetFinal(s, final); 720 } 721 722 // Returns true when arc (or meta-arc) 'arc' out of 's' in 'efst_' is 723 // below the threshold. When 'add_arc' is true, 'arc' is added to 'ofst_'. 724 template <class A> 725 bool PrunedExpand<A>::ProcNonParen(StateId s, const A &arc, bool add_arc) { 726 VLOG(2) << "ProcNonParen: " << s << " to " << arc.nextstate 727 << ", " << arc.ilabel << ":" << arc.olabel << " / " << arc.weight 728 << ", add_arc = " << (add_arc ? "true" : "false"); 729 if (PruneArc(s, arc)) return false; 730 if(add_arc) ofst_->AddArc(s, arc); 731 AddStateAndEnqueue(arc.nextstate); 732 return true; 733 } 734 735 // Processes an open paren arc 'arc' out of state 's' in 'ofst_'. 736 // When 'arc' is labeled with an open paren, 737 // 1. considers each (shortest) balanced path starting in 's' by 738 // taking 'arc' and ending by a close paren balancing the open 739 // paren of 'arc' as a meta-arc, processes and prunes each meta-arc 740 // as a non-paren arc, inserting its destination to the queue; 741 // 2. if at least one of these meta-arcs has not been pruned, 742 // adds the destination of 'arc' to 'ofst_' as a new source state 743 // for the stack id 'nsi' and inserts it in the queue. 744 template <class A> 745 bool PrunedExpand<A>::ProcOpenParen(StateId s, const A &arc, StackId si, 746 StackId nsi) { 747 // Update the stack lenght when needed: |nsi| = |si| + 1. 748 while (stack_length_.size() <= nsi) stack_length_.push_back(-1); 749 if (stack_length_[nsi] == -1) 750 stack_length_[nsi] = stack_length_[si] + 1; 751 752 StateId ns = arc.nextstate; 753 VLOG(2) << "Open paren: " << s << "(" << state_table_.Tuple(s).state_id 754 << ") to " << ns << "(" << state_table_.Tuple(ns).state_id << ")"; 755 bool proc_arc = false; 756 Weight fd = Weight::Zero(); 757 ssize_t paren_id = stack_.ParenId(arc.ilabel); 758 slist<StateId> sources; 759 for (SetIterator set_iter = 760 balance_data_->Find(paren_id, state_table_.Tuple(ns).state_id); 761 !set_iter.Done(); set_iter.Next()) { 762 sources.push_front(set_iter.Element()); 763 } 764 for (typename slist<StateId>::const_iterator sources_iter = sources.begin(); 765 sources_iter != sources.end(); 766 ++ sources_iter) { 767 StateId source = *sources_iter; 768 VLOG(2) << "Close paren source: " << source; 769 ParenState<Arc> paren_state(paren_id, source); 770 for (typename ParenMultimap::const_iterator iter = 771 close_paren_multimap_.find(paren_state); 772 iter != close_paren_multimap_.end() && paren_state == iter->first; 773 ++iter) { 774 Arc meta_arc = iter->second; 775 PdtStateTuple<StateId, StackId> tuple(meta_arc.nextstate, si); 776 meta_arc.nextstate = state_table_.FindState(tuple); 777 VLOG(2) << state_table_.Tuple(ns).state_id << ", " << source; 778 VLOG(2) << "Meta arc weight = " << arc.weight << " Times " 779 << DistanceToDest(state_table_.Tuple(ns).state_id, source) 780 << " Times " << meta_arc.weight; 781 meta_arc.weight = Times( 782 arc.weight, 783 Times(DistanceToDest(state_table_.Tuple(ns).state_id, source), 784 meta_arc.weight)); 785 proc_arc |= ProcNonParen(s, meta_arc, false); 786 fd = Plus(fd, Times( 787 Times( 788 DistanceToDest(state_table_.Tuple(ns).state_id, source), 789 iter->second.weight), 790 FinalDistance(meta_arc.nextstate))); 791 } 792 } 793 if (proc_arc) { 794 VLOG(2) << "Proc open paren " << s << " to " << arc.nextstate; 795 ofst_->AddArc( 796 s, keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate)); 797 AddStateAndEnqueue(arc.nextstate); 798 Weight nd = Times(Distance(s), arc.weight); 799 if(less_(nd, Distance(arc.nextstate))) 800 SetDistance(arc.nextstate, nd); 801 // FinalDistance not necessary for source state since pruning 802 // decided using the meta-arcs above. But this is a problem with 803 // A*, hence: 804 if (less_(fd, FinalDistance(arc.nextstate))) 805 SetFinalDistance(arc.nextstate, fd); 806 SetFlags(arc.nextstate, kSourceState, kSourceState); 807 } 808 return proc_arc; 809 } 810 811 // Checks that shortest path through close paren arc in 'efst_' is 812 // below threshold, if so adds it to 'ofst_'. 813 template <class A> 814 bool PrunedExpand<A>::ProcCloseParen(StateId s, const A &arc) { 815 Weight w = Times(Distance(s), 816 Times(arc.weight, FinalDistance(arc.nextstate))); 817 if (less_(limit_, w)) 818 return false; 819 ofst_->AddArc( 820 s, keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate)); 821 return true; 822 } 823 824 // When 's' in 'ofst_' is a source state for stack id 'si', identifies 825 // all the corresponding possible destination states, that is, all the 826 // states in 'ifst_' that have an outgoing close paren arc balancing 827 // the incoming open paren taken to get to 's', and for each such 828 // state 't', computes the shortest distance from (t, si) to the final 829 // states in 'ofst_'. Stores this information in 'dest_map_'. 830 template <class A> 831 void PrunedExpand<A>::ProcDestStates(StateId s, StackId si) { 832 if (!(Flags(s) & kSourceState)) return; 833 if (si != current_stack_id_) { 834 dest_map_.clear(); 835 current_stack_id_ = si; 836 current_paren_id_ = stack_.Top(current_stack_id_); 837 VLOG(2) << "StackID " << si << " dequeued for first time"; 838 } 839 // TODO(allauzen): clean up source state business; rename current function to 840 // ProcSourceState. 841 SetSourceState(s, state_table_.Tuple(s).state_id); 842 843 ssize_t paren_id = stack_.Top(si); 844 for (SetIterator set_iter = 845 balance_data_->Find(paren_id, state_table_.Tuple(s).state_id); 846 !set_iter.Done(); set_iter.Next()) { 847 StateId dest_state = set_iter.Element(); 848 if (dest_map_.find(dest_state) != dest_map_.end()) 849 continue; 850 Weight dest_weight = Weight::Zero(); 851 ParenState<Arc> paren_state(paren_id, dest_state); 852 for (typename ParenMultimap::const_iterator iter = 853 close_paren_multimap_.find(paren_state); 854 iter != close_paren_multimap_.end() && paren_state == iter->first; 855 ++iter) { 856 const Arc &arc = iter->second; 857 PdtStateTuple<StateId, StackId> tuple(arc.nextstate, stack_.Pop(si)); 858 dest_weight = Plus(dest_weight, 859 Times(arc.weight, 860 FinalDistance(state_table_.FindState(tuple)))); 861 } 862 dest_map_[dest_state] = dest_weight; 863 VLOG(2) << "State " << dest_state << " is a dest state for stack id " 864 << si << " with weight " << dest_weight; 865 } 866 } 867 868 // Expands and prunes with weight threshold 'threshold' the input PDT. 869 // Writes the result in 'ofst'. 870 template <class A> 871 void PrunedExpand<A>::Expand( 872 MutableFst<A> *ofst, const typename A::Weight &threshold) { 873 ofst_ = ofst; 874 ofst_->DeleteStates(); 875 ofst_->SetInputSymbols(ifst_->InputSymbols()); 876 ofst_->SetOutputSymbols(ifst_->OutputSymbols()); 877 878 limit_ = Times(DistanceToDest(ifst_->Start(), rfst_.Start() - 1), threshold); 879 flags_.clear(); 880 881 ProcStart(); 882 883 while (!queue_.Empty()) { 884 StateId s = queue_.Head(); 885 queue_.Dequeue(); 886 SetFlags(s, kExpanded, kExpanded | kEnqueued); 887 VLOG(2) << s << " dequeued!"; 888 889 ProcFinal(s); 890 StackId stack_id = state_table_.Tuple(s).stack_id; 891 ProcDestStates(s, stack_id); 892 893 for (ArcIterator<ExpandFst<Arc> > aiter(efst_, s); 894 !aiter.Done(); 895 aiter.Next()) { 896 Arc arc = aiter.Value(); 897 StackId nextstack_id = state_table_.Tuple(arc.nextstate).stack_id; 898 if (stack_id == nextstack_id) 899 ProcNonParen(s, arc, true); 900 else if (stack_id == stack_.Pop(nextstack_id)) 901 ProcOpenParen(s, arc, stack_id, nextstack_id); 902 else 903 ProcCloseParen(s, arc); 904 } 905 VLOG(2) << "d[" << s << "] = " << Distance(s) 906 << ", fd[" << s << "] = " << FinalDistance(s); 907 } 908 } 909 910 // 911 // Expand() Functions 912 // 913 914 template <class Arc> 915 struct ExpandOptions { 916 bool connect; 917 bool keep_parentheses; 918 typename Arc::Weight weight_threshold; 919 920 ExpandOptions(bool c = true, bool k = false, 921 typename Arc::Weight w = Arc::Weight::Zero()) 922 : connect(c), keep_parentheses(k), weight_threshold(w) {} 923 }; 924 925 // Expands a pushdown transducer (PDT) encoded as an FST into an FST. 926 // This version writes the expanded PDT result to a MutableFst. 927 // In the PDT, some transitions are labeled with open or close 928 // parentheses. To be interpreted as a PDT, the parens must balance on 929 // a path. The open-close parenthesis label pairs are passed in 930 // 'parens'. The expansion enforces the parenthesis constraints. The 931 // PDT must be expandable as an FST. 932 template <class Arc> 933 void Expand( 934 const Fst<Arc> &ifst, 935 const vector<pair<typename Arc::Label, typename Arc::Label> > &parens, 936 MutableFst<Arc> *ofst, 937 const ExpandOptions<Arc> &opts) { 938 typedef typename Arc::Label Label; 939 typedef typename Arc::StateId StateId; 940 typedef typename Arc::Weight Weight; 941 typedef typename ExpandFst<Arc>::StackId StackId; 942 943 ExpandFstOptions<Arc> eopts; 944 eopts.gc_limit = 0; 945 if (opts.weight_threshold == Weight::Zero()) { 946 eopts.keep_parentheses = opts.keep_parentheses; 947 *ofst = ExpandFst<Arc>(ifst, parens, eopts); 948 } else { 949 PrunedExpand<Arc> pruned_expand(ifst, parens, opts.keep_parentheses); 950 pruned_expand.Expand(ofst, opts.weight_threshold); 951 } 952 953 if (opts.connect) 954 Connect(ofst); 955 } 956 957 // Expands a pushdown transducer (PDT) encoded as an FST into an FST. 958 // This version writes the expanded PDT result to a MutableFst. 959 // In the PDT, some transitions are labeled with open or close 960 // parentheses. To be interpreted as a PDT, the parens must balance on 961 // a path. The open-close parenthesis label pairs are passed in 962 // 'parens'. The expansion enforces the parenthesis constraints. The 963 // PDT must be expandable as an FST. 964 template<class Arc> 965 void Expand( 966 const Fst<Arc> &ifst, 967 const vector<pair<typename Arc::Label, typename Arc::Label> > &parens, 968 MutableFst<Arc> *ofst, 969 bool connect = true, bool keep_parentheses = false) { 970 Expand(ifst, parens, ofst, ExpandOptions<Arc>(connect, keep_parentheses)); 971 } 972 973 } // namespace fst 974 975 #endif // FST_EXTENSIONS_PDT_EXPAND_H__ 976