1 // shortest-path.h 2 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Copyright 2005-2010 Google, Inc. 16 // Author: riley (at) google.com (Michael Riley) 17 // 18 // \file 19 // Functions to find shortest paths in a PDT. 20 21 #ifndef FST_EXTENSIONS_PDT_SHORTEST_PATH_H__ 22 #define FST_EXTENSIONS_PDT_SHORTEST_PATH_H__ 23 24 #include <fst/shortest-path.h> 25 #include <fst/extensions/pdt/paren.h> 26 #include <fst/extensions/pdt/pdt.h> 27 28 #include <tr1/unordered_map> 29 using std::tr1::unordered_map; 30 using std::tr1::unordered_multimap; 31 #include <tr1/unordered_set> 32 using std::tr1::unordered_set; 33 using std::tr1::unordered_multiset; 34 #include <stack> 35 #include <vector> 36 using std::vector; 37 38 namespace fst { 39 40 template <class Arc, class Queue> 41 struct PdtShortestPathOptions { 42 bool keep_parentheses; 43 bool path_gc; 44 45 PdtShortestPathOptions(bool kp = false, bool gc = true) 46 : keep_parentheses(kp), path_gc(gc) {} 47 }; 48 49 50 // Class to store PDT shortest path results. Stores shortest path 51 // tree info 'Distance()', Parent(), and ArcParent() information keyed 52 // on two types: 53 // (1) By SearchState: This is a usual node in a shortest path tree but: 54 // (a) is w.r.t a PDT search state - a pair of a PDT state and 55 // a 'start' state, which is either the PDT start state or 56 // the destination state of an open parenthesis. 57 // (b) the Distance() is from this 'start' state to the search state. 58 // (c) Parent().state is kNoLabel for the 'start' state. 59 // 60 // (2) By ParenSpec: This connects shortest path trees depending on the 61 // the parenthesis taken. Given the parenthesis spec: 62 // (a) the Distance() is from the Parent() 'start' state to the 63 // parenthesis destination state. 64 // (b) the ArcParent() is the parenthesis arc. 65 template <class Arc> 66 class PdtShortestPathData { 67 public: 68 static const uint8 kFinal; 69 70 typedef typename Arc::StateId StateId; 71 typedef typename Arc::Weight Weight; 72 typedef typename Arc::Label Label; 73 74 struct SearchState { 75 SearchState() : state(kNoStateId), start(kNoStateId) {} 76 77 SearchState(StateId s, StateId t) : state(s), start(t) {} 78 79 bool operator==(const SearchState &s) const { 80 if (&s == this) 81 return true; 82 return s.state == this->state && s.start == this->start; 83 } 84 85 StateId state; // PDT state 86 StateId start; // PDT paren 'source' state 87 }; 88 89 90 // Specifies paren id, source and dest 'start' states of a paren. 91 // These are the 'start' states of the respective sub-graphs. 92 struct ParenSpec { 93 ParenSpec() 94 : paren_id(kNoLabel), src_start(kNoStateId), dest_start(kNoStateId) {} 95 96 ParenSpec(Label id, StateId s, StateId d) 97 : paren_id(id), src_start(s), dest_start(d) {} 98 99 Label paren_id; // Id of parenthesis 100 StateId src_start; // sub-graph 'start' state for paren source. 101 StateId dest_start; // sub-graph 'start' state for paren dest. 102 103 bool operator==(const ParenSpec &x) const { 104 if (&x == this) 105 return true; 106 return x.paren_id == this->paren_id && 107 x.src_start == this->src_start && 108 x.dest_start == this->dest_start; 109 } 110 }; 111 112 struct SearchData { 113 SearchData() : distance(Weight::Zero()), 114 parent(kNoStateId, kNoStateId), 115 paren_id(kNoLabel), 116 flags(0) {} 117 118 Weight distance; // Distance to this state from PDT 'start' state 119 SearchState parent; // Parent state in shortest path tree 120 int16 paren_id; // If parent arc has paren, paren ID, o.w. kNoLabel 121 uint8 flags; // First byte reserved for PdtShortestPathData use 122 }; 123 124 PdtShortestPathData(bool gc) 125 : state_(kNoStateId, kNoStateId), 126 paren_(kNoLabel, kNoStateId, kNoStateId), 127 gc_(gc), 128 nstates_(0), 129 ngc_(0), 130 finished_(false) {} 131 132 ~PdtShortestPathData() { 133 VLOG(1) << "opm size: " << paren_map_.size(); 134 VLOG(1) << "# of search states: " << nstates_; 135 if (gc_) 136 VLOG(1) << "# of GC'd search states: " << ngc_; 137 } 138 139 void Clear() { 140 search_map_.clear(); 141 search_multimap_.clear(); 142 paren_map_.clear(); 143 state_ = SearchState(kNoStateId, kNoStateId); 144 nstates_ = 0; 145 ngc_ = 0; 146 } 147 148 Weight Distance(SearchState s) const { 149 SearchData *data = GetSearchData(s); 150 return data->distance; 151 } 152 153 Weight Distance(const ParenSpec &paren) const { 154 SearchData *data = GetSearchData(paren); 155 return data->distance; 156 } 157 158 SearchState Parent(SearchState s) const { 159 SearchData *data = GetSearchData(s); 160 return data->parent; 161 } 162 163 SearchState Parent(const ParenSpec &paren) const { 164 SearchData *data = GetSearchData(paren); 165 return data->parent; 166 } 167 168 Label ParenId(SearchState s) const { 169 SearchData *data = GetSearchData(s); 170 return data->paren_id; 171 } 172 173 uint8 Flags(SearchState s) const { 174 SearchData *data = GetSearchData(s); 175 return data->flags; 176 } 177 178 void SetDistance(SearchState s, Weight w) { 179 SearchData *data = GetSearchData(s); 180 data->distance = w; 181 } 182 183 void SetDistance(const ParenSpec &paren, Weight w) { 184 SearchData *data = GetSearchData(paren); 185 data->distance = w; 186 } 187 188 void SetParent(SearchState s, SearchState p) { 189 SearchData *data = GetSearchData(s); 190 data->parent = p; 191 } 192 193 void SetParent(const ParenSpec &paren, SearchState p) { 194 SearchData *data = GetSearchData(paren); 195 data->parent = p; 196 } 197 198 void SetParenId(SearchState s, Label p) { 199 if (p >= 32768) 200 FSTERROR() << "PdtShortestPathData: Paren ID does not fits in an int16"; 201 SearchData *data = GetSearchData(s); 202 data->paren_id = p; 203 } 204 205 void SetFlags(SearchState s, uint8 f, uint8 mask) { 206 SearchData *data = GetSearchData(s); 207 data->flags &= ~mask; 208 data->flags |= f & mask; 209 } 210 211 void GC(StateId s); 212 213 void Finish() { finished_ = true; } 214 215 private: 216 static const Arc kNoArc; 217 static const size_t kPrime0; 218 static const size_t kPrime1; 219 static const uint8 kInited; 220 static const uint8 kMarked; 221 222 // Hash for search state 223 struct SearchStateHash { 224 size_t operator()(const SearchState &s) const { 225 return s.state + s.start * kPrime0; 226 } 227 }; 228 229 // Hash for paren map 230 struct ParenHash { 231 size_t operator()(const ParenSpec &paren) const { 232 return paren.paren_id + paren.src_start * kPrime0 + 233 paren.dest_start * kPrime1; 234 } 235 }; 236 237 typedef unordered_map<SearchState, SearchData, SearchStateHash> SearchMap; 238 239 typedef unordered_multimap<StateId, StateId> SearchMultimap; 240 241 // Hash map from paren spec to open paren data 242 typedef unordered_map<ParenSpec, SearchData, ParenHash> ParenMap; 243 244 SearchData *GetSearchData(SearchState s) const { 245 if (s == state_) 246 return state_data_; 247 if (finished_) { 248 typename SearchMap::iterator it = search_map_.find(s); 249 if (it == search_map_.end()) 250 return &null_search_data_; 251 state_ = s; 252 return state_data_ = &(it->second); 253 } else { 254 state_ = s; 255 state_data_ = &search_map_[s]; 256 if (!(state_data_->flags & kInited)) { 257 ++nstates_; 258 if (gc_) 259 search_multimap_.insert(make_pair(s.start, s.state)); 260 state_data_->flags = kInited; 261 } 262 return state_data_; 263 } 264 } 265 266 SearchData *GetSearchData(ParenSpec paren) const { 267 if (paren == paren_) 268 return paren_data_; 269 if (finished_) { 270 typename ParenMap::iterator it = paren_map_.find(paren); 271 if (it == paren_map_.end()) 272 return &null_search_data_; 273 paren_ = paren; 274 return state_data_ = &(it->second); 275 } else { 276 paren_ = paren; 277 return paren_data_ = &paren_map_[paren]; 278 } 279 } 280 281 mutable SearchMap search_map_; // Maps from search state to data 282 mutable SearchMultimap search_multimap_; // Maps from 'start' to subgraph 283 mutable ParenMap paren_map_; // Maps paren spec to search data 284 mutable SearchState state_; // Last state accessed 285 mutable SearchData *state_data_; // Last state data accessed 286 mutable ParenSpec paren_; // Last paren spec accessed 287 mutable SearchData *paren_data_; // Last paren data accessed 288 bool gc_; // Allow GC? 289 mutable size_t nstates_; // Total number of search states 290 size_t ngc_; // Number of GC'd search states 291 mutable SearchData null_search_data_; // Null search data 292 bool finished_; // Read-only access when true 293 294 DISALLOW_COPY_AND_ASSIGN(PdtShortestPathData); 295 }; 296 297 // Deletes inaccessible search data from a given 'start' (open paren dest) 298 // state. Assumes 'final' (close paren source or PDT final) states have 299 // been flagged 'kFinal'. 300 template<class Arc> 301 void PdtShortestPathData<Arc>::GC(StateId start) { 302 if (!gc_) 303 return; 304 vector<StateId> final; 305 for (typename SearchMultimap::iterator mmit = search_multimap_.find(start); 306 mmit != search_multimap_.end() && mmit->first == start; 307 ++mmit) { 308 SearchState s(mmit->second, start); 309 const SearchData &data = search_map_[s]; 310 if (data.flags & kFinal) 311 final.push_back(s.state); 312 } 313 314 // Mark phase 315 for (size_t i = 0; i < final.size(); ++i) { 316 SearchState s(final[i], start); 317 while (s.state != kNoLabel) { 318 SearchData *sdata = &search_map_[s]; 319 if (sdata->flags & kMarked) 320 break; 321 sdata->flags |= kMarked; 322 SearchState p = sdata->parent; 323 if (p.start != start && p.start != kNoLabel) { // entering sub-subgraph 324 ParenSpec paren(sdata->paren_id, s.start, p.start); 325 SearchData *pdata = &paren_map_[paren]; 326 s = pdata->parent; 327 } else { 328 s = p; 329 } 330 } 331 } 332 333 // Sweep phase 334 typename SearchMultimap::iterator mmit = search_multimap_.find(start); 335 while (mmit != search_multimap_.end() && mmit->first == start) { 336 SearchState s(mmit->second, start); 337 typename SearchMap::iterator mit = search_map_.find(s); 338 const SearchData &data = mit->second; 339 if (!(data.flags & kMarked)) { 340 search_map_.erase(mit); 341 ++ngc_; 342 } 343 search_multimap_.erase(mmit++); 344 } 345 } 346 347 template<class Arc> const Arc PdtShortestPathData<Arc>::kNoArc 348 = Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId); 349 350 template<class Arc> const size_t PdtShortestPathData<Arc>::kPrime0 = 7853; 351 352 template<class Arc> const size_t PdtShortestPathData<Arc>::kPrime1 = 7867; 353 354 template<class Arc> const uint8 PdtShortestPathData<Arc>::kInited = 0x01; 355 356 template<class Arc> const uint8 PdtShortestPathData<Arc>::kFinal = 0x02; 357 358 template<class Arc> const uint8 PdtShortestPathData<Arc>::kMarked = 0x04; 359 360 361 // This computes the single source shortest (balanced) path (SSSP) 362 // through a weighted PDT that has a bounded stack (i.e. is expandable 363 // as an FST). It is a generalization of the classic SSSP graph 364 // algorithm that removes a state s from a queue (defined by a 365 // user-provided queue type) and relaxes the destination states of 366 // transitions leaving s. In this PDT version, states that have 367 // entering open parentheses are treated as source states for a 368 // sub-graph SSSP problem with the shortest path up to the open 369 // parenthesis being first saved. When a close parenthesis is then 370 // encountered any balancing open parenthesis is examined for this 371 // saved information and multiplied back. In this way, each sub-graph 372 // is entered only once rather than repeatedly. If every state in the 373 // input PDT has the property that there is a unique 'start' state for 374 // it with entering open parentheses, then this algorithm is quite 375 // straight-forward. In general, this will not be the case, so the 376 // algorithm (implicitly) creates a new graph where each state is a 377 // pair of an original state and a possible parenthesis 'start' state 378 // for that state. 379 template<class Arc, class Queue> 380 class PdtShortestPath { 381 public: 382 typedef typename Arc::StateId StateId; 383 typedef typename Arc::Weight Weight; 384 typedef typename Arc::Label Label; 385 386 typedef PdtShortestPathData<Arc> SpData; 387 typedef typename SpData::SearchState SearchState; 388 typedef typename SpData::ParenSpec ParenSpec; 389 390 typedef typename PdtBalanceData<Arc>::SetIterator CloseSourceIterator; 391 392 PdtShortestPath(const Fst<Arc> &ifst, 393 const vector<pair<Label, Label> > &parens, 394 const PdtShortestPathOptions<Arc, Queue> &opts) 395 : kFinal(SpData::kFinal), 396 ifst_(ifst.Copy()), 397 parens_(parens), 398 keep_parens_(opts.keep_parentheses), 399 start_(ifst.Start()), 400 sp_data_(opts.path_gc), 401 error_(false) { 402 403 if ((Weight::Properties() & (kPath | kRightSemiring)) 404 != (kPath | kRightSemiring)) { 405 FSTERROR() << "PdtShortestPath: Weight needs to have the path" 406 << " property and be right distributive: " << Weight::Type(); 407 error_ = true; 408 } 409 410 for (Label i = 0; i < parens.size(); ++i) { 411 const pair<Label, Label> &p = parens[i]; 412 paren_id_map_[p.first] = i; 413 paren_id_map_[p.second] = i; 414 } 415 }; 416 417 ~PdtShortestPath() { 418 VLOG(1) << "# of input states: " << CountStates(*ifst_); 419 VLOG(1) << "# of enqueued: " << nenqueued_; 420 VLOG(1) << "cpmm size: " << close_paren_multimap_.size(); 421 delete ifst_; 422 } 423 424 void ShortestPath(MutableFst<Arc> *ofst) { 425 Init(ofst); 426 GetDistance(start_); 427 GetPath(); 428 sp_data_.Finish(); 429 if (error_) ofst->SetProperties(kError, kError); 430 } 431 432 const PdtShortestPathData<Arc> &GetShortestPathData() const { 433 return sp_data_; 434 } 435 436 PdtBalanceData<Arc> *GetBalanceData() { return &balance_data_; } 437 438 private: 439 static const Arc kNoArc; 440 static const uint8 kEnqueued; 441 static const uint8 kExpanded; 442 static const uint8 kFinished; 443 const uint8 kFinal; 444 445 public: 446 // Hash multimap from close paren label to an paren arc. 447 typedef unordered_multimap<ParenState<Arc>, Arc, 448 typename ParenState<Arc>::Hash> CloseParenMultimap; 449 450 const CloseParenMultimap &GetCloseParenMultimap() const { 451 return close_paren_multimap_; 452 } 453 454 private: 455 void Init(MutableFst<Arc> *ofst); 456 void GetDistance(StateId start); 457 void ProcFinal(SearchState s); 458 void ProcArcs(SearchState s); 459 void ProcOpenParen(Label paren_id, SearchState s, Arc arc, Weight w); 460 void ProcCloseParen(Label paren_id, SearchState s, const Arc &arc, Weight w); 461 void ProcNonParen(SearchState s, const Arc &arc, Weight w); 462 void Relax(SearchState s, SearchState t, Arc arc, Weight w, Label paren_id); 463 void Enqueue(SearchState d); 464 void GetPath(); 465 Arc GetPathArc(SearchState s, SearchState p, Label paren_id, bool open); 466 467 Fst<Arc> *ifst_; 468 MutableFst<Arc> *ofst_; 469 const vector<pair<Label, Label> > &parens_; 470 bool keep_parens_; 471 Queue *state_queue_; // current state queue 472 StateId start_; 473 Weight f_distance_; 474 SearchState f_parent_; 475 SpData sp_data_; 476 unordered_map<Label, Label> paren_id_map_; 477 CloseParenMultimap close_paren_multimap_; 478 PdtBalanceData<Arc> balance_data_; 479 ssize_t nenqueued_; 480 bool error_; 481 482 DISALLOW_COPY_AND_ASSIGN(PdtShortestPath); 483 }; 484 485 template<class Arc, class Queue> 486 void PdtShortestPath<Arc, Queue>::Init(MutableFst<Arc> *ofst) { 487 ofst_ = ofst; 488 ofst->DeleteStates(); 489 ofst->SetInputSymbols(ifst_->InputSymbols()); 490 ofst->SetOutputSymbols(ifst_->OutputSymbols()); 491 492 if (ifst_->Start() == kNoStateId) 493 return; 494 495 f_distance_ = Weight::Zero(); 496 f_parent_ = SearchState(kNoStateId, kNoStateId); 497 498 sp_data_.Clear(); 499 close_paren_multimap_.clear(); 500 balance_data_.Clear(); 501 nenqueued_ = 0; 502 503 // Find open parens per destination state and close parens per source state. 504 for (StateIterator<Fst<Arc> > siter(*ifst_); !siter.Done(); siter.Next()) { 505 StateId s = siter.Value(); 506 for (ArcIterator<Fst<Arc> > aiter(*ifst_, s); 507 !aiter.Done(); aiter.Next()) { 508 const Arc &arc = aiter.Value(); 509 typename unordered_map<Label, Label>::const_iterator pit 510 = paren_id_map_.find(arc.ilabel); 511 if (pit != paren_id_map_.end()) { // Is a paren? 512 Label paren_id = pit->second; 513 if (arc.ilabel == parens_[paren_id].first) { // Open paren 514 balance_data_.OpenInsert(paren_id, arc.nextstate); 515 } else { // Close paren 516 ParenState<Arc> paren_state(paren_id, s); 517 close_paren_multimap_.insert(make_pair(paren_state, arc)); 518 } 519 } 520 } 521 } 522 } 523 524 // Computes the shortest distance stored in a recursive way. Each 525 // sub-graph (i.e. different paren 'start' state) begins with weight One(). 526 template<class Arc, class Queue> 527 void PdtShortestPath<Arc, Queue>::GetDistance(StateId start) { 528 if (start == kNoStateId) 529 return; 530 531 Queue state_queue; 532 state_queue_ = &state_queue; 533 SearchState q(start, start); 534 Enqueue(q); 535 sp_data_.SetDistance(q, Weight::One()); 536 537 while (!state_queue_->Empty()) { 538 StateId state = state_queue_->Head(); 539 state_queue_->Dequeue(); 540 SearchState s(state, start); 541 sp_data_.SetFlags(s, 0, kEnqueued); 542 ProcFinal(s); 543 ProcArcs(s); 544 sp_data_.SetFlags(s, kExpanded, kExpanded); 545 } 546 sp_data_.SetFlags(q, kFinished, kFinished); 547 balance_data_.FinishInsert(start); 548 sp_data_.GC(start); 549 } 550 551 // Updates best complete path. 552 template<class Arc, class Queue> 553 void PdtShortestPath<Arc, Queue>::ProcFinal(SearchState s) { 554 if (ifst_->Final(s.state) != Weight::Zero() && s.start == start_) { 555 Weight w = Times(sp_data_.Distance(s), 556 ifst_->Final(s.state)); 557 if (f_distance_ != Plus(f_distance_, w)) { 558 if (f_parent_.state != kNoStateId) 559 sp_data_.SetFlags(f_parent_, 0, kFinal); 560 sp_data_.SetFlags(s, kFinal, kFinal); 561 562 f_distance_ = Plus(f_distance_, w); 563 f_parent_ = s; 564 } 565 } 566 } 567 568 // Processes all arcs leaving the state s. 569 template<class Arc, class Queue> 570 void PdtShortestPath<Arc, Queue>::ProcArcs(SearchState s) { 571 for (ArcIterator< Fst<Arc> > aiter(*ifst_, s.state); 572 !aiter.Done(); 573 aiter.Next()) { 574 Arc arc = aiter.Value(); 575 Weight w = Times(sp_data_.Distance(s), arc.weight); 576 577 typename unordered_map<Label, Label>::const_iterator pit 578 = paren_id_map_.find(arc.ilabel); 579 if (pit != paren_id_map_.end()) { // Is a paren? 580 Label paren_id = pit->second; 581 if (arc.ilabel == parens_[paren_id].first) 582 ProcOpenParen(paren_id, s, arc, w); 583 else 584 ProcCloseParen(paren_id, s, arc, w); 585 } else { 586 ProcNonParen(s, arc, w); 587 } 588 } 589 } 590 591 // Saves the shortest path info for reaching this parenthesis 592 // and starts a new SSSP in the sub-graph pointed to by the parenthesis 593 // if previously unvisited. Otherwise it finds any previously encountered 594 // closing parentheses and relaxes them using the recursively stored 595 // shortest distance to them. 596 template<class Arc, class Queue> inline 597 void PdtShortestPath<Arc, Queue>::ProcOpenParen( 598 Label paren_id, SearchState s, Arc arc, Weight w) { 599 600 SearchState d(arc.nextstate, arc.nextstate); 601 ParenSpec paren(paren_id, s.start, d.start); 602 Weight pdist = sp_data_.Distance(paren); 603 if (pdist != Plus(pdist, w)) { 604 sp_data_.SetDistance(paren, w); 605 sp_data_.SetParent(paren, s); 606 Weight dist = sp_data_.Distance(d); 607 if (dist == Weight::Zero()) { 608 Queue *state_queue = state_queue_; 609 GetDistance(d.start); 610 state_queue_ = state_queue; 611 } else if (!(sp_data_.Flags(d) & kFinished)) { 612 FSTERROR() << "PdtShortestPath: open parenthesis recursion: not bounded stack"; 613 error_ = true; 614 } 615 616 for (CloseSourceIterator set_iter = 617 balance_data_.Find(paren_id, arc.nextstate); 618 !set_iter.Done(); set_iter.Next()) { 619 SearchState cpstate(set_iter.Element(), d.start); 620 ParenState<Arc> paren_state(paren_id, cpstate.state); 621 for (typename CloseParenMultimap::const_iterator cpit = 622 close_paren_multimap_.find(paren_state); 623 cpit != close_paren_multimap_.end() && paren_state == cpit->first; 624 ++cpit) { 625 const Arc &cparc = cpit->second; 626 Weight cpw = Times(w, Times(sp_data_.Distance(cpstate), 627 cparc.weight)); 628 Relax(cpstate, s, cparc, cpw, paren_id); 629 } 630 } 631 } 632 } 633 634 // Saves the correspondence between each closing parenthesis and its 635 // balancing open parenthesis info. Relaxes any close parenthesis 636 // destination state that has a balancing previously encountered open 637 // parenthesis. 638 template<class Arc, class Queue> inline 639 void PdtShortestPath<Arc, Queue>::ProcCloseParen( 640 Label paren_id, SearchState s, const Arc &arc, Weight w) { 641 ParenState<Arc> paren_state(paren_id, s.start); 642 if (!(sp_data_.Flags(s) & kExpanded)) { 643 balance_data_.CloseInsert(paren_id, s.start, s.state); 644 sp_data_.SetFlags(s, kFinal, kFinal); 645 } 646 } 647 648 // For non-parentheses, classical relaxation. 649 template<class Arc, class Queue> inline 650 void PdtShortestPath<Arc, Queue>::ProcNonParen( 651 SearchState s, const Arc &arc, Weight w) { 652 Relax(s, s, arc, w, kNoLabel); 653 } 654 655 // Classical relaxation on the search graph for 'arc' from state 's'. 656 // State 't' is in the same sub-graph as the nextstate should be (i.e. 657 // has the same paren 'start'. 658 template<class Arc, class Queue> inline 659 void PdtShortestPath<Arc, Queue>::Relax( 660 SearchState s, SearchState t, Arc arc, Weight w, Label paren_id) { 661 SearchState d(arc.nextstate, t.start); 662 Weight dist = sp_data_.Distance(d); 663 if (dist != Plus(dist, w)) { 664 sp_data_.SetParent(d, s); 665 sp_data_.SetParenId(d, paren_id); 666 sp_data_.SetDistance(d, Plus(dist, w)); 667 Enqueue(d); 668 } 669 } 670 671 template<class Arc, class Queue> inline 672 void PdtShortestPath<Arc, Queue>::Enqueue(SearchState s) { 673 if (!(sp_data_.Flags(s) & kEnqueued)) { 674 state_queue_->Enqueue(s.state); 675 sp_data_.SetFlags(s, kEnqueued, kEnqueued); 676 ++nenqueued_; 677 } else { 678 state_queue_->Update(s.state); 679 } 680 } 681 682 // Follows parent pointers to find the shortest path. Uses a stack 683 // since the shortest distance is stored recursively. 684 template<class Arc, class Queue> 685 void PdtShortestPath<Arc, Queue>::GetPath() { 686 SearchState s = f_parent_, d = SearchState(kNoStateId, kNoStateId); 687 StateId s_p = kNoStateId, d_p = kNoStateId; 688 Arc arc(kNoArc); 689 Label paren_id = kNoLabel; 690 stack<ParenSpec> paren_stack; 691 while (s.state != kNoStateId) { 692 d_p = s_p; 693 s_p = ofst_->AddState(); 694 if (d.state == kNoStateId) { 695 ofst_->SetFinal(s_p, ifst_->Final(f_parent_.state)); 696 } else { 697 if (paren_id != kNoLabel) { // paren? 698 if (arc.ilabel == parens_[paren_id].first) { // open paren 699 paren_stack.pop(); 700 } else { // close paren 701 ParenSpec paren(paren_id, d.start, s.start); 702 paren_stack.push(paren); 703 } 704 if (!keep_parens_) 705 arc.ilabel = arc.olabel = 0; 706 } 707 arc.nextstate = d_p; 708 ofst_->AddArc(s_p, arc); 709 } 710 d = s; 711 s = sp_data_.Parent(d); 712 paren_id = sp_data_.ParenId(d); 713 if (s.state != kNoStateId) { 714 arc = GetPathArc(s, d, paren_id, false); 715 } else if (!paren_stack.empty()) { 716 ParenSpec paren = paren_stack.top(); 717 s = sp_data_.Parent(paren); 718 paren_id = paren.paren_id; 719 arc = GetPathArc(s, d, paren_id, true); 720 } 721 } 722 ofst_->SetStart(s_p); 723 ofst_->SetProperties( 724 ShortestPathProperties(ofst_->Properties(kFstProperties, false)), 725 kFstProperties); 726 } 727 728 729 // Finds transition with least weight between two states with label matching 730 // paren_id and open/close paren type or a non-paren if kNoLabel. 731 template<class Arc, class Queue> 732 Arc PdtShortestPath<Arc, Queue>::GetPathArc( 733 SearchState s, SearchState d, Label paren_id, bool open_paren) { 734 Arc path_arc = kNoArc; 735 for (ArcIterator< Fst<Arc> > aiter(*ifst_, s.state); 736 !aiter.Done(); 737 aiter.Next()) { 738 const Arc &arc = aiter.Value(); 739 if (arc.nextstate != d.state) 740 continue; 741 Label arc_paren_id = kNoLabel; 742 typename unordered_map<Label, Label>::const_iterator pit 743 = paren_id_map_.find(arc.ilabel); 744 if (pit != paren_id_map_.end()) { 745 arc_paren_id = pit->second; 746 bool arc_open_paren = arc.ilabel == parens_[arc_paren_id].first; 747 if (arc_open_paren != open_paren) 748 continue; 749 } 750 if (arc_paren_id != paren_id) 751 continue; 752 if (arc.weight == Plus(arc.weight, path_arc.weight)) 753 path_arc = arc; 754 } 755 if (path_arc.nextstate == kNoStateId) { 756 FSTERROR() << "PdtShortestPath::GetPathArc failed to find arc"; 757 error_ = true; 758 } 759 return path_arc; 760 } 761 762 template<class Arc, class Queue> 763 const Arc PdtShortestPath<Arc, Queue>::kNoArc 764 = Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId); 765 766 template<class Arc, class Queue> 767 const uint8 PdtShortestPath<Arc, Queue>::kEnqueued = 0x10; 768 769 template<class Arc, class Queue> 770 const uint8 PdtShortestPath<Arc, Queue>::kExpanded = 0x20; 771 772 template<class Arc, class Queue> 773 const uint8 PdtShortestPath<Arc, Queue>::kFinished = 0x40; 774 775 template<class Arc, class Queue> 776 void ShortestPath(const Fst<Arc> &ifst, 777 const vector<pair<typename Arc::Label, 778 typename Arc::Label> > &parens, 779 MutableFst<Arc> *ofst, 780 const PdtShortestPathOptions<Arc, Queue> &opts) { 781 PdtShortestPath<Arc, Queue> psp(ifst, parens, opts); 782 psp.ShortestPath(ofst); 783 } 784 785 template<class Arc> 786 void ShortestPath(const Fst<Arc> &ifst, 787 const vector<pair<typename Arc::Label, 788 typename Arc::Label> > &parens, 789 MutableFst<Arc> *ofst) { 790 typedef FifoQueue<typename Arc::StateId> Queue; 791 PdtShortestPathOptions<Arc, Queue> opts; 792 PdtShortestPath<Arc, Queue> psp(ifst, parens, opts); 793 psp.ShortestPath(ofst); 794 } 795 796 } // namespace fst 797 798 #endif // FST_EXTENSIONS_PDT_SHORTEST_PATH_H__ 799