1 // shortest-path.h 2 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // Copyright 2005-2010 Google, Inc. 16 // Author: riley (at) google.com (Michael Riley) 17 // 18 // \file 19 // Functions to find shortest paths in a PDT. 20 21 #ifndef FST_EXTENSIONS_PDT_SHORTEST_PATH_H__ 22 #define FST_EXTENSIONS_PDT_SHORTEST_PATH_H__ 23 24 #include <fst/shortest-path.h> 25 #include <fst/extensions/pdt/paren.h> 26 #include <fst/extensions/pdt/pdt.h> 27 28 #include <unordered_map> 29 using std::tr1::unordered_map; 30 using std::tr1::unordered_multimap; 31 #include <tr1/unordered_set> 32 using std::tr1::unordered_set; 33 using std::tr1::unordered_multiset; 34 #include <stack> 35 #include <vector> 36 using std::vector; 37 38 namespace fst { 39 40 template <class Arc, class Queue> 41 struct PdtShortestPathOptions { 42 bool keep_parentheses; 43 bool path_gc; 44 45 PdtShortestPathOptions(bool kp = false, bool gc = true) 46 : keep_parentheses(kp), path_gc(gc) {} 47 }; 48 49 50 // Class to store PDT shortest path results. Stores shortest path 51 // tree info 'Distance()', Parent(), and ArcParent() information keyed 52 // on two types: 53 // (1) By SearchState: This is a usual node in a shortest path tree but: 54 // (a) is w.r.t a PDT search state - a pair of a PDT state and 55 // a 'start' state, which is either the PDT start state or 56 // the destination state of an open parenthesis. 57 // (b) the Distance() is from this 'start' state to the search state. 58 // (c) Parent().state is kNoLabel for the 'start' state. 59 // 60 // (2) By ParenSpec: This connects shortest path trees depending on the 61 // the parenthesis taken. Given the parenthesis spec: 62 // (a) the Distance() is from the Parent() 'start' state to the 63 // parenthesis destination state. 64 // (b) the ArcParent() is the parenthesis arc. 65 template <class Arc> 66 class PdtShortestPathData { 67 public: 68 static const uint8 kFinal; 69 70 typedef typename Arc::StateId StateId; 71 typedef typename Arc::Weight Weight; 72 typedef typename Arc::Label Label; 73 74 struct SearchState { 75 SearchState() : state(kNoStateId), start(kNoStateId) {} 76 77 SearchState(StateId s, StateId t) : state(s), start(t) {} 78 79 bool operator==(const SearchState &s) const { 80 if (&s == this) 81 return true; 82 return s.state == this->state && s.start == this->start; 83 } 84 85 StateId state; // PDT state 86 StateId start; // PDT paren 'source' state 87 }; 88 89 90 // Specifies paren id, source and dest 'start' states of a paren. 91 // These are the 'start' states of the respective sub-graphs. 92 struct ParenSpec { 93 ParenSpec() 94 : paren_id(kNoLabel), src_start(kNoStateId), dest_start(kNoStateId) {} 95 96 ParenSpec(Label id, StateId s, StateId d) 97 : paren_id(id), src_start(s), dest_start(d) {} 98 99 Label paren_id; // Id of parenthesis 100 StateId src_start; // sub-graph 'start' state for paren source. 101 StateId dest_start; // sub-graph 'start' state for paren dest. 102 103 bool operator==(const ParenSpec &x) const { 104 if (&x == this) 105 return true; 106 return x.paren_id == this->paren_id && 107 x.src_start == this->src_start && 108 x.dest_start == this->dest_start; 109 } 110 }; 111 112 struct SearchData { 113 SearchData() : distance(Weight::Zero()), 114 parent(kNoStateId, kNoStateId), 115 paren_id(kNoLabel), 116 flags(0) {} 117 118 Weight distance; // Distance to this state from PDT 'start' state 119 SearchState parent; // Parent state in shortest path tree 120 int16 paren_id; // If parent arc has paren, paren ID, o.w. kNoLabel 121 uint8 flags; // First byte reserved for PdtShortestPathData use 122 }; 123 124 PdtShortestPathData(bool gc) 125 : state_(kNoStateId, kNoStateId), 126 paren_(kNoLabel, kNoStateId, kNoStateId), 127 gc_(gc), 128 nstates_(0), 129 ngc_(0), 130 finished_(false) {} 131 132 ~PdtShortestPathData() { 133 VLOG(1) << "opm size: " << paren_map_.size(); 134 VLOG(1) << "# of search states: " << nstates_; 135 if (gc_) 136 VLOG(1) << "# of GC'd search states: " << ngc_; 137 } 138 139 void Clear() { 140 search_map_.clear(); 141 search_multimap_.clear(); 142 paren_map_.clear(); 143 state_ = SearchState(kNoStateId, kNoStateId); 144 nstates_ = 0; 145 ngc_ = 0; 146 } 147 148 Weight Distance(SearchState s) const { 149 SearchData *data = GetSearchData(s); 150 return data->distance; 151 } 152 153 Weight Distance(const ParenSpec &paren) const { 154 SearchData *data = GetSearchData(paren); 155 return data->distance; 156 } 157 158 SearchState Parent(SearchState s) const { 159 SearchData *data = GetSearchData(s); 160 return data->parent; 161 } 162 163 SearchState Parent(const ParenSpec &paren) const { 164 SearchData *data = GetSearchData(paren); 165 return data->parent; 166 } 167 168 Label ParenId(SearchState s) const { 169 SearchData *data = GetSearchData(s); 170 return data->paren_id; 171 } 172 173 uint8 Flags(SearchState s) const { 174 SearchData *data = GetSearchData(s); 175 return data->flags; 176 } 177 178 void SetDistance(SearchState s, Weight w) { 179 SearchData *data = GetSearchData(s); 180 data->distance = w; 181 } 182 183 void SetDistance(const ParenSpec &paren, Weight w) { 184 SearchData *data = GetSearchData(paren); 185 data->distance = w; 186 } 187 188 void SetParent(SearchState s, SearchState p) { 189 SearchData *data = GetSearchData(s); 190 data->parent = p; 191 } 192 193 void SetParent(const ParenSpec &paren, SearchState p) { 194 SearchData *data = GetSearchData(paren); 195 data->parent = p; 196 } 197 198 void SetParenId(SearchState s, Label p) { 199 if (p >= 32768) 200 FSTERROR() << "PdtShortestPathData: Paren ID does not fits in an int16"; 201 SearchData *data = GetSearchData(s); 202 data->paren_id = p; 203 } 204 205 void SetFlags(SearchState s, uint8 f, uint8 mask) { 206 SearchData *data = GetSearchData(s); 207 data->flags &= ~mask; 208 data->flags |= f & mask; 209 } 210 211 void GC(StateId s); 212 213 void Finish() { finished_ = true; } 214 215 private: 216 static const Arc kNoArc; 217 static const size_t kPrime0; 218 static const size_t kPrime1; 219 static const uint8 kInited; 220 static const uint8 kMarked; 221 222 // Hash for search state 223 struct SearchStateHash { 224 size_t operator()(const SearchState &s) const { 225 return s.state + s.start * kPrime0; 226 } 227 }; 228 229 // Hash for paren map 230 struct ParenHash { 231 size_t operator()(const ParenSpec &paren) const { 232 return paren.paren_id + paren.src_start * kPrime0 + 233 paren.dest_start * kPrime1; 234 } 235 }; 236 237 typedef unordered_map<SearchState, SearchData, SearchStateHash> SearchMap; 238 239 typedef unordered_multimap<StateId, StateId> SearchMultimap; 240 241 // Hash map from paren spec to open paren data 242 typedef unordered_map<ParenSpec, SearchData, ParenHash> ParenMap; 243 244 SearchData *GetSearchData(SearchState s) const { 245 if (s == state_) 246 return state_data_; 247 if (finished_) { 248 typename SearchMap::iterator it = search_map_.find(s); 249 if (it == search_map_.end()) 250 return &null_search_data_; 251 state_ = s; 252 return state_data_ = &(it->second); 253 } else { 254 state_ = s; 255 state_data_ = &search_map_[s]; 256 if (!(state_data_->flags & kInited)) { 257 ++nstates_; 258 if (gc_) 259 search_multimap_.insert(make_pair(s.start, s.state)); 260 state_data_->flags = kInited; 261 } 262 return state_data_; 263 } 264 } 265 266 SearchData *GetSearchData(ParenSpec paren) const { 267 if (paren == paren_) 268 return paren_data_; 269 if (finished_) { 270 typename ParenMap::iterator it = paren_map_.find(paren); 271 if (it == paren_map_.end()) 272 return &null_search_data_; 273 paren_ = paren; 274 return state_data_ = &(it->second); 275 } else { 276 paren_ = paren; 277 return paren_data_ = &paren_map_[paren]; 278 } 279 } 280 281 mutable SearchMap search_map_; // Maps from search state to data 282 mutable SearchMultimap search_multimap_; // Maps from 'start' to subgraph 283 mutable ParenMap paren_map_; // Maps paren spec to search data 284 mutable SearchState state_; // Last state accessed 285 mutable SearchData *state_data_; // Last state data accessed 286 mutable ParenSpec paren_; // Last paren spec accessed 287 mutable SearchData *paren_data_; // Last paren data accessed 288 bool gc_; // Allow GC? 289 mutable size_t nstates_; // Total number of search states 290 size_t ngc_; // Number of GC'd search states 291 mutable SearchData null_search_data_; // Null search data 292 bool finished_; // Read-only access when true 293 294 DISALLOW_COPY_AND_ASSIGN(PdtShortestPathData); 295 }; 296 297 // Deletes inaccessible search data from a given 'start' (open paren dest) 298 // state. Assumes 'final' (close paren source or PDT final) states have 299 // been flagged 'kFinal'. 300 template<class Arc> 301 void PdtShortestPathData<Arc>::GC(StateId start) { 302 if (!gc_) 303 return; 304 vector<StateId> final; 305 for (typename SearchMultimap::iterator mmit = search_multimap_.find(start); 306 mmit != search_multimap_.end() && mmit->first == start; 307 ++mmit) { 308 SearchState s(mmit->second, start); 309 const SearchData &data = search_map_[s]; 310 if (data.flags & kFinal) 311 final.push_back(s.state); 312 } 313 314 // Mark phase 315 for (size_t i = 0; i < final.size(); ++i) { 316 SearchState s(final[i], start); 317 while (s.state != kNoLabel) { 318 SearchData *sdata = &search_map_[s]; 319 if (sdata->flags & kMarked) 320 break; 321 sdata->flags |= kMarked; 322 SearchState p = sdata->parent; 323 if (p.start != start && p.start != kNoLabel) { // entering sub-subgraph 324 ParenSpec paren(sdata->paren_id, s.start, p.start); 325 SearchData *pdata = &paren_map_[paren]; 326 s = pdata->parent; 327 } else { 328 s = p; 329 } 330 } 331 } 332 333 // Sweep phase 334 typename SearchMultimap::iterator mmit = search_multimap_.find(start); 335 while (mmit != search_multimap_.end() && mmit->first == start) { 336 SearchState s(mmit->second, start); 337 typename SearchMap::iterator mit = search_map_.find(s); 338 const SearchData &data = mit->second; 339 if (!(data.flags & kMarked)) { 340 search_map_.erase(mit); 341 ++ngc_; 342 } 343 search_multimap_.erase(mmit++); 344 } 345 } 346 347 template<class Arc> const Arc PdtShortestPathData<Arc>::kNoArc 348 = Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId); 349 350 template<class Arc> const size_t PdtShortestPathData<Arc>::kPrime0 = 7853; 351 352 template<class Arc> const size_t PdtShortestPathData<Arc>::kPrime1 = 7867; 353 354 template<class Arc> const uint8 PdtShortestPathData<Arc>::kInited = 0x01; 355 356 template<class Arc> const uint8 PdtShortestPathData<Arc>::kFinal = 0x02; 357 358 template<class Arc> const uint8 PdtShortestPathData<Arc>::kMarked = 0x04; 359 360 361 // This computes the single source shortest (balanced) path (SSSP) 362 // through a weighted PDT that has a bounded stack (i.e. is expandable 363 // as an FST). It is a generalization of the classic SSSP graph 364 // algorithm that removes a state s from a queue (defined by a 365 // user-provided queue type) and relaxes the destination states of 366 // transitions leaving s. In this PDT version, states that have 367 // entering open parentheses are treated as source states for a 368 // sub-graph SSSP problem with the shortest path up to the open 369 // parenthesis being first saved. When a close parenthesis is then 370 // encountered any balancing open parenthesis is examined for this 371 // saved information and multiplied back. In this way, each sub-graph 372 // is entered only once rather than repeatedly. If every state in the 373 // input PDT has the property that there is a unique 'start' state for 374 // it with entering open parentheses, then this algorithm is quite 375 // straight-forward. In general, this will not be the case, so the 376 // algorithm (implicitly) creates a new graph where each state is a 377 // pair of an original state and a possible parenthesis 'start' state 378 // for that state. 379 template<class Arc, class Queue> 380 class PdtShortestPath { 381 public: 382 typedef typename Arc::StateId StateId; 383 typedef typename Arc::Weight Weight; 384 typedef typename Arc::Label Label; 385 386 typedef PdtShortestPathData<Arc> SpData; 387 typedef typename SpData::SearchState SearchState; 388 typedef typename SpData::ParenSpec ParenSpec; 389 390 typedef typename PdtParenReachable<Arc>::SetIterator StateSetIterator; 391 typedef typename PdtBalanceData<Arc>::SetIterator CloseSourceIterator; 392 393 PdtShortestPath(const Fst<Arc> &ifst, 394 const vector<pair<Label, Label> > &parens, 395 const PdtShortestPathOptions<Arc, Queue> &opts) 396 : kFinal(SpData::kFinal), 397 ifst_(ifst.Copy()), 398 parens_(parens), 399 keep_parens_(opts.keep_parentheses), 400 start_(ifst.Start()), 401 sp_data_(opts.path_gc), 402 error_(false) { 403 404 if ((Weight::Properties() & (kPath | kRightSemiring)) 405 != (kPath | kRightSemiring)) { 406 FSTERROR() << "SingleShortestPath: Weight needs to have the path" 407 << " property and be right distributive: " << Weight::Type(); 408 error_ = true; 409 } 410 411 for (Label i = 0; i < parens.size(); ++i) { 412 const pair<Label, Label> &p = parens[i]; 413 paren_id_map_[p.first] = i; 414 paren_id_map_[p.second] = i; 415 } 416 }; 417 418 ~PdtShortestPath() { 419 VLOG(1) << "# of input states: " << CountStates(*ifst_); 420 VLOG(1) << "# of enqueued: " << nenqueued_; 421 VLOG(1) << "cpmm size: " << close_paren_multimap_.size(); 422 delete ifst_; 423 } 424 425 void ShortestPath(MutableFst<Arc> *ofst) { 426 Init(ofst); 427 GetDistance(start_); 428 GetPath(); 429 sp_data_.Finish(); 430 if (error_) ofst->SetProperties(kError, kError); 431 } 432 433 const PdtShortestPathData<Arc> &GetShortestPathData() const { 434 return sp_data_; 435 } 436 437 PdtBalanceData<Arc> *GetBalanceData() { return &balance_data_; } 438 439 private: 440 static const Arc kNoArc; 441 static const uint8 kEnqueued; 442 static const uint8 kExpanded; 443 const uint8 kFinal; 444 445 public: 446 // Hash multimap from close paren label to an paren arc. 447 typedef unordered_multimap<ParenState<Arc>, Arc, 448 typename ParenState<Arc>::Hash> CloseParenMultimap; 449 450 const CloseParenMultimap &GetCloseParenMultimap() const { 451 return close_paren_multimap_; 452 } 453 454 private: 455 void Init(MutableFst<Arc> *ofst); 456 void GetDistance(StateId start); 457 void ProcFinal(SearchState s); 458 void ProcArcs(SearchState s); 459 void ProcOpenParen(Label paren_id, SearchState s, Arc arc, Weight w); 460 void ProcCloseParen(Label paren_id, SearchState s, const Arc &arc, Weight w); 461 void ProcNonParen(SearchState s, const Arc &arc, Weight w); 462 void Relax(SearchState s, SearchState t, Arc arc, Weight w, Label paren_id); 463 void Enqueue(SearchState d); 464 void GetPath(); 465 Arc GetPathArc(SearchState s, SearchState p, Label paren_id, bool open); 466 467 Fst<Arc> *ifst_; 468 MutableFst<Arc> *ofst_; 469 const vector<pair<Label, Label> > &parens_; 470 bool keep_parens_; 471 Queue *state_queue_; // current state queue 472 StateId start_; 473 Weight f_distance_; 474 SearchState f_parent_; 475 SpData sp_data_; 476 unordered_map<Label, Label> paren_id_map_; 477 CloseParenMultimap close_paren_multimap_; 478 PdtBalanceData<Arc> balance_data_; 479 ssize_t nenqueued_; 480 bool error_; 481 482 DISALLOW_COPY_AND_ASSIGN(PdtShortestPath); 483 }; 484 485 template<class Arc, class Queue> 486 void PdtShortestPath<Arc, Queue>::Init(MutableFst<Arc> *ofst) { 487 ofst_ = ofst; 488 ofst->DeleteStates(); 489 ofst->SetInputSymbols(ifst_->InputSymbols()); 490 ofst->SetOutputSymbols(ifst_->OutputSymbols()); 491 492 if (ifst_->Start() == kNoStateId) 493 return; 494 495 f_distance_ = Weight::Zero(); 496 f_parent_ = SearchState(kNoStateId, kNoStateId); 497 498 sp_data_.Clear(); 499 close_paren_multimap_.clear(); 500 balance_data_.Clear(); 501 nenqueued_ = 0; 502 503 // Find open parens per destination state and close parens per source state. 504 for (StateIterator<Fst<Arc> > siter(*ifst_); !siter.Done(); siter.Next()) { 505 StateId s = siter.Value(); 506 for (ArcIterator<Fst<Arc> > aiter(*ifst_, s); 507 !aiter.Done(); aiter.Next()) { 508 const Arc &arc = aiter.Value(); 509 typename unordered_map<Label, Label>::const_iterator pit 510 = paren_id_map_.find(arc.ilabel); 511 if (pit != paren_id_map_.end()) { // Is a paren? 512 Label paren_id = pit->second; 513 if (arc.ilabel == parens_[paren_id].first) { // Open paren 514 balance_data_.OpenInsert(paren_id, arc.nextstate); 515 } else { // Close paren 516 ParenState<Arc> paren_state(paren_id, s); 517 close_paren_multimap_.insert(make_pair(paren_state, arc)); 518 } 519 } 520 } 521 } 522 } 523 524 // Computes the shortest distance stored in a recursive way. Each 525 // sub-graph (i.e. different paren 'start' state) begins with weight One(). 526 template<class Arc, class Queue> 527 void PdtShortestPath<Arc, Queue>::GetDistance(StateId start) { 528 if (start == kNoStateId) 529 return; 530 531 Queue state_queue; 532 state_queue_ = &state_queue; 533 SearchState q(start, start); 534 Enqueue(q); 535 sp_data_.SetDistance(q, Weight::One()); 536 537 while (!state_queue_->Empty()) { 538 StateId state = state_queue_->Head(); 539 state_queue_->Dequeue(); 540 SearchState s(state, start); 541 sp_data_.SetFlags(s, 0, kEnqueued); 542 ProcFinal(s); 543 ProcArcs(s); 544 sp_data_.SetFlags(s, kExpanded, kExpanded); 545 } 546 balance_data_.FinishInsert(start); 547 sp_data_.GC(start); 548 } 549 550 // Updates best complete path. 551 template<class Arc, class Queue> 552 void PdtShortestPath<Arc, Queue>::ProcFinal(SearchState s) { 553 if (ifst_->Final(s.state) != Weight::Zero() && s.start == start_) { 554 Weight w = Times(sp_data_.Distance(s), 555 ifst_->Final(s.state)); 556 if (f_distance_ != Plus(f_distance_, w)) { 557 if (f_parent_.state != kNoStateId) 558 sp_data_.SetFlags(f_parent_, 0, kFinal); 559 sp_data_.SetFlags(s, kFinal, kFinal); 560 561 f_distance_ = Plus(f_distance_, w); 562 f_parent_ = s; 563 } 564 } 565 } 566 567 // Processes all arcs leaving the state s. 568 template<class Arc, class Queue> 569 void PdtShortestPath<Arc, Queue>::ProcArcs(SearchState s) { 570 for (ArcIterator< Fst<Arc> > aiter(*ifst_, s.state); 571 !aiter.Done(); 572 aiter.Next()) { 573 Arc arc = aiter.Value(); 574 Weight w = Times(sp_data_.Distance(s), arc.weight); 575 576 typename unordered_map<Label, Label>::const_iterator pit 577 = paren_id_map_.find(arc.ilabel); 578 if (pit != paren_id_map_.end()) { // Is a paren? 579 Label paren_id = pit->second; 580 if (arc.ilabel == parens_[paren_id].first) 581 ProcOpenParen(paren_id, s, arc, w); 582 else 583 ProcCloseParen(paren_id, s, arc, w); 584 } else { 585 ProcNonParen(s, arc, w); 586 } 587 } 588 } 589 590 // Saves the shortest path info for reaching this parenthesis 591 // and starts a new SSSP in the sub-graph pointed to by the parenthesis 592 // if previously unvisited. Otherwise it finds any previously encountered 593 // closing parentheses and relaxes them using the recursively stored 594 // shortest distance to them. 595 template<class Arc, class Queue> inline 596 void PdtShortestPath<Arc, Queue>::ProcOpenParen( 597 Label paren_id, SearchState s, Arc arc, Weight w) { 598 599 SearchState d(arc.nextstate, arc.nextstate); 600 ParenSpec paren(paren_id, s.start, d.start); 601 Weight pdist = sp_data_.Distance(paren); 602 if (pdist != Plus(pdist, w)) { 603 sp_data_.SetDistance(paren, w); 604 sp_data_.SetParent(paren, s); 605 Weight dist = sp_data_.Distance(d); 606 if (dist == Weight::Zero()) { 607 Queue *state_queue = state_queue_; 608 GetDistance(d.start); 609 state_queue_ = state_queue; 610 } 611 for (CloseSourceIterator set_iter = 612 balance_data_.Find(paren_id, arc.nextstate); 613 !set_iter.Done(); set_iter.Next()) { 614 SearchState cpstate(set_iter.Element(), d.start); 615 ParenState<Arc> paren_state(paren_id, cpstate.state); 616 for (typename CloseParenMultimap::const_iterator cpit = 617 close_paren_multimap_.find(paren_state); 618 cpit != close_paren_multimap_.end() && paren_state == cpit->first; 619 ++cpit) { 620 const Arc &cparc = cpit->second; 621 Weight cpw = Times(w, Times(sp_data_.Distance(cpstate), 622 cparc.weight)); 623 Relax(cpstate, s, cparc, cpw, paren_id); 624 } 625 } 626 } 627 } 628 629 // Saves the correspondence between each closing parenthesis and its 630 // balancing open parenthesis info. Relaxes any close parenthesis 631 // destination state that has a balancing previously encountered open 632 // parenthesis. 633 template<class Arc, class Queue> inline 634 void PdtShortestPath<Arc, Queue>::ProcCloseParen( 635 Label paren_id, SearchState s, const Arc &arc, Weight w) { 636 ParenState<Arc> paren_state(paren_id, s.start); 637 if (!(sp_data_.Flags(s) & kExpanded)) { 638 balance_data_.CloseInsert(paren_id, s.start, s.state); 639 sp_data_.SetFlags(s, kFinal, kFinal); 640 } 641 } 642 643 // For non-parentheses, classical relaxation. 644 template<class Arc, class Queue> inline 645 void PdtShortestPath<Arc, Queue>::ProcNonParen( 646 SearchState s, const Arc &arc, Weight w) { 647 Relax(s, s, arc, w, kNoLabel); 648 } 649 650 // Classical relaxation on the search graph for 'arc' from state 's'. 651 // State 't' is in the same sub-graph as the nextstate should be (i.e. 652 // has the same paren 'start'. 653 template<class Arc, class Queue> inline 654 void PdtShortestPath<Arc, Queue>::Relax( 655 SearchState s, SearchState t, Arc arc, Weight w, Label paren_id) { 656 SearchState d(arc.nextstate, t.start); 657 Weight dist = sp_data_.Distance(d); 658 if (dist != Plus(dist, w)) { 659 sp_data_.SetParent(d, s); 660 sp_data_.SetParenId(d, paren_id); 661 sp_data_.SetDistance(d, Plus(dist, w)); 662 Enqueue(d); 663 } 664 } 665 666 template<class Arc, class Queue> inline 667 void PdtShortestPath<Arc, Queue>::Enqueue(SearchState s) { 668 if (!(sp_data_.Flags(s) & kEnqueued)) { 669 state_queue_->Enqueue(s.state); 670 sp_data_.SetFlags(s, kEnqueued, kEnqueued); 671 ++nenqueued_; 672 } else { 673 state_queue_->Update(s.state); 674 } 675 } 676 677 // Follows parent pointers to find the shortest path. Uses a stack 678 // since the shortest distance is stored recursively. 679 template<class Arc, class Queue> 680 void PdtShortestPath<Arc, Queue>::GetPath() { 681 SearchState s = f_parent_, d = SearchState(kNoStateId, kNoStateId); 682 StateId s_p = kNoStateId, d_p = kNoStateId; 683 Arc arc(kNoArc); 684 Label paren_id = kNoLabel; 685 stack<ParenSpec> paren_stack; 686 while (s.state != kNoStateId) { 687 d_p = s_p; 688 s_p = ofst_->AddState(); 689 if (d.state == kNoStateId) { 690 ofst_->SetFinal(s_p, ifst_->Final(f_parent_.state)); 691 } else { 692 if (paren_id != kNoLabel) { // paren? 693 if (arc.ilabel == parens_[paren_id].first) { // open paren 694 paren_stack.pop(); 695 } else { // close paren 696 ParenSpec paren(paren_id, d.start, s.start); 697 paren_stack.push(paren); 698 } 699 if (!keep_parens_) 700 arc.ilabel = arc.olabel = 0; 701 } 702 arc.nextstate = d_p; 703 ofst_->AddArc(s_p, arc); 704 } 705 d = s; 706 s = sp_data_.Parent(d); 707 paren_id = sp_data_.ParenId(d); 708 if (s.state != kNoStateId) { 709 arc = GetPathArc(s, d, paren_id, false); 710 } else if (!paren_stack.empty()) { 711 ParenSpec paren = paren_stack.top(); 712 s = sp_data_.Parent(paren); 713 paren_id = paren.paren_id; 714 arc = GetPathArc(s, d, paren_id, true); 715 } 716 } 717 ofst_->SetStart(s_p); 718 ofst_->SetProperties( 719 ShortestPathProperties(ofst_->Properties(kFstProperties, false)), 720 kFstProperties); 721 } 722 723 724 // Finds transition with least weight between two states with label matching 725 // paren_id and open/close paren type or a non-paren if kNoLabel. 726 template<class Arc, class Queue> 727 Arc PdtShortestPath<Arc, Queue>::GetPathArc( 728 SearchState s, SearchState d, Label paren_id, bool open_paren) { 729 Arc path_arc = kNoArc; 730 for (ArcIterator< Fst<Arc> > aiter(*ifst_, s.state); 731 !aiter.Done(); 732 aiter.Next()) { 733 const Arc &arc = aiter.Value(); 734 if (arc.nextstate != d.state) 735 continue; 736 Label arc_paren_id = kNoLabel; 737 typename unordered_map<Label, Label>::const_iterator pit 738 = paren_id_map_.find(arc.ilabel); 739 if (pit != paren_id_map_.end()) { 740 arc_paren_id = pit->second; 741 bool arc_open_paren = arc.ilabel == parens_[arc_paren_id].first; 742 if (arc_open_paren != open_paren) 743 continue; 744 } 745 if (arc_paren_id != paren_id) 746 continue; 747 if (arc.weight == Plus(arc.weight, path_arc.weight)) 748 path_arc = arc; 749 } 750 if (path_arc.nextstate == kNoStateId) { 751 FSTERROR() << "PdtShortestPath::GetPathArc failed to find arc"; 752 error_ = true; 753 } 754 return path_arc; 755 } 756 757 template<class Arc, class Queue> 758 const Arc PdtShortestPath<Arc, Queue>::kNoArc 759 = Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId); 760 761 template<class Arc, class Queue> 762 const uint8 PdtShortestPath<Arc, Queue>::kEnqueued = 0x10; 763 764 template<class Arc, class Queue> 765 const uint8 PdtShortestPath<Arc, Queue>::kExpanded = 0x20; 766 767 template<class Arc, class Queue> 768 void ShortestPath(const Fst<Arc> &ifst, 769 const vector<pair<typename Arc::Label, 770 typename Arc::Label> > &parens, 771 MutableFst<Arc> *ofst, 772 const PdtShortestPathOptions<Arc, Queue> &opts) { 773 PdtShortestPath<Arc, Queue> psp(ifst, parens, opts); 774 psp.ShortestPath(ofst); 775 } 776 777 template<class Arc> 778 void ShortestPath(const Fst<Arc> &ifst, 779 const vector<pair<typename Arc::Label, 780 typename Arc::Label> > &parens, 781 MutableFst<Arc> *ofst) { 782 typedef FifoQueue<typename Arc::StateId> Queue; 783 PdtShortestPathOptions<Arc, Queue> opts; 784 PdtShortestPath<Arc, Queue> psp(ifst, parens, opts); 785 psp.ShortestPath(ofst); 786 } 787 788 } // namespace fst 789 790 #endif // FST_EXTENSIONS_PDT_SHORTEST_PATH_H__ 791