Home | History | Annotate | Download | only in pdt
      1 // shortest-path.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: riley (at) google.com (Michael Riley)
     17 //
     18 // \file
     19 // Functions to find shortest paths in a PDT.
     20 
     21 #ifndef FST_EXTENSIONS_PDT_SHORTEST_PATH_H__
     22 #define FST_EXTENSIONS_PDT_SHORTEST_PATH_H__
     23 
     24 #include <fst/shortest-path.h>
     25 #include <fst/extensions/pdt/paren.h>
     26 #include <fst/extensions/pdt/pdt.h>
     27 
     28 #include <tr1/unordered_map>
     29 using std::tr1::unordered_map;
     30 using std::tr1::unordered_multimap;
     31 #include <tr1/unordered_set>
     32 using std::tr1::unordered_set;
     33 using std::tr1::unordered_multiset;
     34 #include <stack>
     35 #include <vector>
     36 using std::vector;
     37 
     38 namespace fst {
     39 
     40 template <class Arc, class Queue>
     41 struct PdtShortestPathOptions {
     42   bool keep_parentheses;
     43   bool path_gc;
     44 
     45   PdtShortestPathOptions(bool kp = false, bool gc = true)
     46       : keep_parentheses(kp), path_gc(gc) {}
     47 };
     48 
     49 
     50 // Class to store PDT shortest path results. Stores shortest path
     51 // tree info 'Distance()', Parent(), and ArcParent() information keyed
     52 // on two types:
     53 // (1) By SearchState: This is a usual node in a shortest path tree but:
     54 //    (a) is w.r.t a PDT search state - a pair of a PDT state and
     55 //        a 'start' state, which is either the PDT start state or
     56 //        the destination state of an open parenthesis.
     57 //    (b) the Distance() is from this 'start' state to the search state.
     58 //    (c) Parent().state is kNoLabel for the 'start' state.
     59 //
     60 // (2) By ParenSpec: This connects shortest path trees depending on the
     61 // the parenthesis taken. Given the parenthesis spec:
     62 //    (a) the Distance() is from the Parent() 'start' state to the
     63 //     parenthesis destination state.
     64 //    (b) the ArcParent() is the parenthesis arc.
     65 template <class Arc>
     66 class PdtShortestPathData {
     67  public:
     68   static const uint8 kFinal;
     69 
     70   typedef typename Arc::StateId StateId;
     71   typedef typename Arc::Weight Weight;
     72   typedef typename Arc::Label Label;
     73 
     74   struct SearchState {
     75     SearchState() : state(kNoStateId), start(kNoStateId) {}
     76 
     77     SearchState(StateId s, StateId t) : state(s), start(t) {}
     78 
     79     bool operator==(const SearchState &s) const {
     80       if (&s == this)
     81         return true;
     82       return s.state == this->state && s.start == this->start;
     83     }
     84 
     85     StateId state;  // PDT state
     86     StateId start;  // PDT paren 'source' state
     87   };
     88 
     89 
     90   // Specifies paren id, source and dest 'start' states of a paren.
     91   // These are the 'start' states of the respective sub-graphs.
     92   struct ParenSpec {
     93     ParenSpec()
     94         : paren_id(kNoLabel), src_start(kNoStateId), dest_start(kNoStateId) {}
     95 
     96     ParenSpec(Label id, StateId s, StateId d)
     97         : paren_id(id), src_start(s), dest_start(d) {}
     98 
     99     Label paren_id;        // Id of parenthesis
    100     StateId src_start;     // sub-graph 'start' state for paren source.
    101     StateId dest_start;    // sub-graph 'start' state for paren dest.
    102 
    103     bool operator==(const ParenSpec &x) const {
    104       if (&x == this)
    105         return true;
    106       return x.paren_id == this->paren_id &&
    107           x.src_start == this->src_start &&
    108           x.dest_start == this->dest_start;
    109     }
    110   };
    111 
    112   struct SearchData {
    113     SearchData() : distance(Weight::Zero()),
    114                    parent(kNoStateId, kNoStateId),
    115                    paren_id(kNoLabel),
    116                    flags(0) {}
    117 
    118     Weight distance;     // Distance to this state from PDT 'start' state
    119     SearchState parent;  // Parent state in shortest path tree
    120     int16 paren_id;      // If parent arc has paren, paren ID, o.w. kNoLabel
    121     uint8 flags;         // First byte reserved for PdtShortestPathData use
    122   };
    123 
    124   PdtShortestPathData(bool gc)
    125       : state_(kNoStateId, kNoStateId),
    126         paren_(kNoLabel, kNoStateId, kNoStateId),
    127         gc_(gc),
    128         nstates_(0),
    129         ngc_(0),
    130         finished_(false) {}
    131 
    132   ~PdtShortestPathData() {
    133     VLOG(1) << "opm size: " << paren_map_.size();
    134     VLOG(1) << "# of search states: " << nstates_;
    135     if (gc_)
    136       VLOG(1) << "# of GC'd search states: " << ngc_;
    137   }
    138 
    139   void Clear() {
    140     search_map_.clear();
    141     search_multimap_.clear();
    142     paren_map_.clear();
    143     state_ = SearchState(kNoStateId, kNoStateId);
    144     nstates_ = 0;
    145     ngc_ = 0;
    146   }
    147 
    148   Weight Distance(SearchState s) const {
    149     SearchData *data = GetSearchData(s);
    150     return data->distance;
    151   }
    152 
    153   Weight Distance(const ParenSpec &paren) const {
    154     SearchData *data = GetSearchData(paren);
    155     return data->distance;
    156   }
    157 
    158   SearchState Parent(SearchState s) const {
    159     SearchData *data = GetSearchData(s);
    160     return data->parent;
    161   }
    162 
    163   SearchState Parent(const ParenSpec &paren) const {
    164     SearchData *data = GetSearchData(paren);
    165     return data->parent;
    166   }
    167 
    168   Label ParenId(SearchState s) const {
    169     SearchData *data = GetSearchData(s);
    170     return data->paren_id;
    171   }
    172 
    173   uint8 Flags(SearchState s) const {
    174     SearchData *data = GetSearchData(s);
    175     return data->flags;
    176   }
    177 
    178   void SetDistance(SearchState s, Weight w) {
    179     SearchData *data = GetSearchData(s);
    180     data->distance = w;
    181   }
    182 
    183   void SetDistance(const ParenSpec &paren, Weight w) {
    184     SearchData *data = GetSearchData(paren);
    185     data->distance = w;
    186   }
    187 
    188   void SetParent(SearchState s, SearchState p) {
    189     SearchData *data = GetSearchData(s);
    190     data->parent = p;
    191   }
    192 
    193   void SetParent(const ParenSpec &paren, SearchState p) {
    194     SearchData *data = GetSearchData(paren);
    195     data->parent = p;
    196   }
    197 
    198   void SetParenId(SearchState s, Label p) {
    199     if (p >= 32768)
    200       FSTERROR() << "PdtShortestPathData: Paren ID does not fits in an int16";
    201     SearchData *data = GetSearchData(s);
    202     data->paren_id = p;
    203   }
    204 
    205   void SetFlags(SearchState s, uint8 f, uint8 mask) {
    206     SearchData *data = GetSearchData(s);
    207     data->flags &= ~mask;
    208     data->flags |= f & mask;
    209   }
    210 
    211   void GC(StateId s);
    212 
    213   void Finish() { finished_ = true; }
    214 
    215  private:
    216   static const Arc kNoArc;
    217   static const size_t kPrime0;
    218   static const size_t kPrime1;
    219   static const uint8 kInited;
    220   static const uint8 kMarked;
    221 
    222   // Hash for search state
    223   struct SearchStateHash {
    224     size_t operator()(const SearchState &s) const {
    225       return s.state + s.start * kPrime0;
    226     }
    227   };
    228 
    229   // Hash for paren map
    230   struct ParenHash {
    231     size_t operator()(const ParenSpec &paren) const {
    232       return paren.paren_id + paren.src_start * kPrime0 +
    233           paren.dest_start * kPrime1;
    234     }
    235   };
    236 
    237   typedef unordered_map<SearchState, SearchData, SearchStateHash> SearchMap;
    238 
    239   typedef unordered_multimap<StateId, StateId> SearchMultimap;
    240 
    241   // Hash map from paren spec to open paren data
    242   typedef unordered_map<ParenSpec, SearchData, ParenHash> ParenMap;
    243 
    244   SearchData *GetSearchData(SearchState s) const {
    245     if (s == state_)
    246       return state_data_;
    247     if (finished_) {
    248       typename SearchMap::iterator it = search_map_.find(s);
    249       if (it == search_map_.end())
    250         return &null_search_data_;
    251       state_ = s;
    252       return state_data_ = &(it->second);
    253     } else {
    254       state_ = s;
    255       state_data_ = &search_map_[s];
    256       if (!(state_data_->flags & kInited)) {
    257         ++nstates_;
    258         if (gc_)
    259           search_multimap_.insert(make_pair(s.start, s.state));
    260         state_data_->flags = kInited;
    261       }
    262       return state_data_;
    263     }
    264   }
    265 
    266   SearchData *GetSearchData(ParenSpec paren) const {
    267     if (paren == paren_)
    268       return paren_data_;
    269     if (finished_) {
    270       typename ParenMap::iterator it = paren_map_.find(paren);
    271       if (it == paren_map_.end())
    272         return &null_search_data_;
    273       paren_ = paren;
    274       return state_data_ = &(it->second);
    275     } else {
    276       paren_ = paren;
    277       return paren_data_ = &paren_map_[paren];
    278     }
    279   }
    280 
    281   mutable SearchMap search_map_;            // Maps from search state to data
    282   mutable SearchMultimap search_multimap_;  // Maps from 'start' to subgraph
    283   mutable ParenMap paren_map_;              // Maps paren spec to search data
    284   mutable SearchState state_;               // Last state accessed
    285   mutable SearchData *state_data_;          // Last state data accessed
    286   mutable ParenSpec paren_;                 // Last paren spec accessed
    287   mutable SearchData *paren_data_;          // Last paren data accessed
    288   bool gc_;                                 // Allow GC?
    289   mutable size_t nstates_;                  // Total number of search states
    290   size_t ngc_;                              // Number of GC'd search states
    291   mutable SearchData null_search_data_;     // Null search data
    292   bool finished_;                           // Read-only access when true
    293 
    294   DISALLOW_COPY_AND_ASSIGN(PdtShortestPathData);
    295 };
    296 
    297 // Deletes inaccessible search data from a given 'start' (open paren dest)
    298 // state. Assumes 'final' (close paren source or PDT final) states have
    299 // been flagged 'kFinal'.
    300 template<class Arc>
    301 void  PdtShortestPathData<Arc>::GC(StateId start) {
    302   if (!gc_)
    303     return;
    304   vector<StateId> final;
    305   for (typename SearchMultimap::iterator mmit = search_multimap_.find(start);
    306        mmit != search_multimap_.end() && mmit->first == start;
    307        ++mmit) {
    308     SearchState s(mmit->second, start);
    309     const SearchData &data = search_map_[s];
    310     if (data.flags & kFinal)
    311       final.push_back(s.state);
    312   }
    313 
    314   // Mark phase
    315   for (size_t i = 0; i < final.size(); ++i) {
    316     SearchState s(final[i], start);
    317     while (s.state != kNoLabel) {
    318       SearchData *sdata = &search_map_[s];
    319       if (sdata->flags & kMarked)
    320         break;
    321       sdata->flags |= kMarked;
    322       SearchState p = sdata->parent;
    323       if (p.start != start && p.start != kNoLabel) {  // entering sub-subgraph
    324         ParenSpec paren(sdata->paren_id, s.start, p.start);
    325         SearchData *pdata = &paren_map_[paren];
    326         s = pdata->parent;
    327       } else {
    328         s = p;
    329       }
    330     }
    331   }
    332 
    333   // Sweep phase
    334   typename SearchMultimap::iterator mmit = search_multimap_.find(start);
    335   while (mmit != search_multimap_.end() && mmit->first == start) {
    336     SearchState s(mmit->second, start);
    337     typename SearchMap::iterator mit = search_map_.find(s);
    338     const SearchData &data = mit->second;
    339     if (!(data.flags & kMarked)) {
    340       search_map_.erase(mit);
    341       ++ngc_;
    342     }
    343     search_multimap_.erase(mmit++);
    344   }
    345 }
    346 
    347 template<class Arc> const Arc PdtShortestPathData<Arc>::kNoArc
    348     = Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId);
    349 
    350 template<class Arc> const size_t PdtShortestPathData<Arc>::kPrime0 = 7853;
    351 
    352 template<class Arc> const size_t PdtShortestPathData<Arc>::kPrime1 = 7867;
    353 
    354 template<class Arc> const uint8 PdtShortestPathData<Arc>::kInited = 0x01;
    355 
    356 template<class Arc> const uint8 PdtShortestPathData<Arc>::kFinal =  0x02;
    357 
    358 template<class Arc> const uint8 PdtShortestPathData<Arc>::kMarked = 0x04;
    359 
    360 
    361 // This computes the single source shortest (balanced) path (SSSP)
    362 // through a weighted PDT that has a bounded stack (i.e. is expandable
    363 // as an FST). It is a generalization of the classic SSSP graph
    364 // algorithm that removes a state s from a queue (defined by a
    365 // user-provided queue type) and relaxes the destination states of
    366 // transitions leaving s. In this PDT version, states that have
    367 // entering open parentheses are treated as source states for a
    368 // sub-graph SSSP problem with the shortest path up to the open
    369 // parenthesis being first saved. When a close parenthesis is then
    370 // encountered any balancing open parenthesis is examined for this
    371 // saved information and multiplied back. In this way, each sub-graph
    372 // is entered only once rather than repeatedly.  If every state in the
    373 // input PDT has the property that there is a unique 'start' state for
    374 // it with entering open parentheses, then this algorithm is quite
    375 // straight-forward. In general, this will not be the case, so the
    376 // algorithm (implicitly) creates a new graph where each state is a
    377 // pair of an original state and a possible parenthesis 'start' state
    378 // for that state.
    379 template<class Arc, class Queue>
    380 class PdtShortestPath {
    381  public:
    382   typedef typename Arc::StateId StateId;
    383   typedef typename Arc::Weight Weight;
    384   typedef typename Arc::Label Label;
    385 
    386   typedef PdtShortestPathData<Arc> SpData;
    387   typedef typename SpData::SearchState SearchState;
    388   typedef typename SpData::ParenSpec ParenSpec;
    389 
    390   typedef typename PdtBalanceData<Arc>::SetIterator CloseSourceIterator;
    391 
    392   PdtShortestPath(const Fst<Arc> &ifst,
    393                   const vector<pair<Label, Label> > &parens,
    394                   const PdtShortestPathOptions<Arc, Queue> &opts)
    395       : kFinal(SpData::kFinal),
    396         ifst_(ifst.Copy()),
    397         parens_(parens),
    398         keep_parens_(opts.keep_parentheses),
    399         start_(ifst.Start()),
    400         sp_data_(opts.path_gc),
    401         error_(false) {
    402 
    403     if ((Weight::Properties() & (kPath | kRightSemiring))
    404         != (kPath | kRightSemiring)) {
    405       FSTERROR() << "PdtShortestPath: Weight needs to have the path"
    406                  << " property and be right distributive: " << Weight::Type();
    407       error_ = true;
    408     }
    409 
    410     for (Label i = 0; i < parens.size(); ++i) {
    411       const pair<Label, Label>  &p = parens[i];
    412       paren_id_map_[p.first] = i;
    413       paren_id_map_[p.second] = i;
    414     }
    415   };
    416 
    417   ~PdtShortestPath() {
    418     VLOG(1) << "# of input states: " << CountStates(*ifst_);
    419     VLOG(1) << "# of enqueued: " << nenqueued_;
    420     VLOG(1) << "cpmm size: " << close_paren_multimap_.size();
    421     delete ifst_;
    422   }
    423 
    424   void ShortestPath(MutableFst<Arc> *ofst) {
    425     Init(ofst);
    426     GetDistance(start_);
    427     GetPath();
    428     sp_data_.Finish();
    429     if (error_) ofst->SetProperties(kError, kError);
    430   }
    431 
    432   const PdtShortestPathData<Arc> &GetShortestPathData() const {
    433     return sp_data_;
    434   }
    435 
    436   PdtBalanceData<Arc> *GetBalanceData() { return &balance_data_; }
    437 
    438  private:
    439   static const Arc kNoArc;
    440   static const uint8 kEnqueued;
    441   static const uint8 kExpanded;
    442   static const uint8 kFinished;
    443   const uint8 kFinal;
    444 
    445  public:
    446   // Hash multimap from close paren label to an paren arc.
    447   typedef unordered_multimap<ParenState<Arc>, Arc,
    448                         typename ParenState<Arc>::Hash> CloseParenMultimap;
    449 
    450   const CloseParenMultimap &GetCloseParenMultimap() const {
    451     return close_paren_multimap_;
    452   }
    453 
    454  private:
    455   void Init(MutableFst<Arc> *ofst);
    456   void GetDistance(StateId start);
    457   void ProcFinal(SearchState s);
    458   void ProcArcs(SearchState s);
    459   void ProcOpenParen(Label paren_id, SearchState s, Arc arc, Weight w);
    460   void ProcCloseParen(Label paren_id, SearchState s, const Arc &arc, Weight w);
    461   void ProcNonParen(SearchState s, const Arc &arc, Weight w);
    462   void Relax(SearchState s, SearchState t, Arc arc, Weight w, Label paren_id);
    463   void Enqueue(SearchState d);
    464   void GetPath();
    465   Arc GetPathArc(SearchState s, SearchState p, Label paren_id, bool open);
    466 
    467   Fst<Arc> *ifst_;
    468   MutableFst<Arc> *ofst_;
    469   const vector<pair<Label, Label> > &parens_;
    470   bool keep_parens_;
    471   Queue *state_queue_;                   // current state queue
    472   StateId start_;
    473   Weight f_distance_;
    474   SearchState f_parent_;
    475   SpData sp_data_;
    476   unordered_map<Label, Label> paren_id_map_;
    477   CloseParenMultimap close_paren_multimap_;
    478   PdtBalanceData<Arc> balance_data_;
    479   ssize_t nenqueued_;
    480   bool error_;
    481 
    482   DISALLOW_COPY_AND_ASSIGN(PdtShortestPath);
    483 };
    484 
    485 template<class Arc, class Queue>
    486 void PdtShortestPath<Arc, Queue>::Init(MutableFst<Arc> *ofst) {
    487   ofst_ = ofst;
    488   ofst->DeleteStates();
    489   ofst->SetInputSymbols(ifst_->InputSymbols());
    490   ofst->SetOutputSymbols(ifst_->OutputSymbols());
    491 
    492   if (ifst_->Start() == kNoStateId)
    493     return;
    494 
    495   f_distance_ = Weight::Zero();
    496   f_parent_ = SearchState(kNoStateId, kNoStateId);
    497 
    498   sp_data_.Clear();
    499   close_paren_multimap_.clear();
    500   balance_data_.Clear();
    501   nenqueued_ = 0;
    502 
    503   // Find open parens per destination state and close parens per source state.
    504   for (StateIterator<Fst<Arc> > siter(*ifst_); !siter.Done(); siter.Next()) {
    505     StateId s = siter.Value();
    506     for (ArcIterator<Fst<Arc> > aiter(*ifst_, s);
    507          !aiter.Done(); aiter.Next()) {
    508       const Arc &arc = aiter.Value();
    509       typename unordered_map<Label, Label>::const_iterator pit
    510           = paren_id_map_.find(arc.ilabel);
    511       if (pit != paren_id_map_.end()) {               // Is a paren?
    512         Label paren_id = pit->second;
    513         if (arc.ilabel == parens_[paren_id].first) {  // Open paren
    514           balance_data_.OpenInsert(paren_id, arc.nextstate);
    515         } else {                                      // Close paren
    516           ParenState<Arc> paren_state(paren_id, s);
    517           close_paren_multimap_.insert(make_pair(paren_state, arc));
    518         }
    519       }
    520     }
    521   }
    522 }
    523 
    524 // Computes the shortest distance stored in a recursive way. Each
    525 // sub-graph (i.e. different paren 'start' state) begins with weight One().
    526 template<class Arc, class Queue>
    527 void PdtShortestPath<Arc, Queue>::GetDistance(StateId start) {
    528   if (start == kNoStateId)
    529     return;
    530 
    531   Queue state_queue;
    532   state_queue_ = &state_queue;
    533   SearchState q(start, start);
    534   Enqueue(q);
    535   sp_data_.SetDistance(q, Weight::One());
    536 
    537   while (!state_queue_->Empty()) {
    538     StateId state = state_queue_->Head();
    539     state_queue_->Dequeue();
    540     SearchState s(state, start);
    541     sp_data_.SetFlags(s, 0, kEnqueued);
    542     ProcFinal(s);
    543     ProcArcs(s);
    544     sp_data_.SetFlags(s, kExpanded, kExpanded);
    545   }
    546   sp_data_.SetFlags(q, kFinished, kFinished);
    547   balance_data_.FinishInsert(start);
    548   sp_data_.GC(start);
    549 }
    550 
    551 // Updates best complete path.
    552 template<class Arc, class Queue>
    553 void PdtShortestPath<Arc, Queue>::ProcFinal(SearchState s) {
    554   if (ifst_->Final(s.state) != Weight::Zero() && s.start == start_) {
    555     Weight w = Times(sp_data_.Distance(s),
    556                      ifst_->Final(s.state));
    557     if (f_distance_ != Plus(f_distance_, w)) {
    558       if (f_parent_.state != kNoStateId)
    559         sp_data_.SetFlags(f_parent_, 0, kFinal);
    560       sp_data_.SetFlags(s, kFinal, kFinal);
    561 
    562       f_distance_ = Plus(f_distance_, w);
    563       f_parent_ = s;
    564     }
    565   }
    566 }
    567 
    568 // Processes all arcs leaving the state s.
    569 template<class Arc, class Queue>
    570 void PdtShortestPath<Arc, Queue>::ProcArcs(SearchState s) {
    571   for (ArcIterator< Fst<Arc> > aiter(*ifst_, s.state);
    572        !aiter.Done();
    573        aiter.Next()) {
    574     Arc arc = aiter.Value();
    575     Weight w = Times(sp_data_.Distance(s), arc.weight);
    576 
    577     typename unordered_map<Label, Label>::const_iterator pit
    578         = paren_id_map_.find(arc.ilabel);
    579     if (pit != paren_id_map_.end()) {  // Is a paren?
    580       Label paren_id = pit->second;
    581       if (arc.ilabel == parens_[paren_id].first)
    582         ProcOpenParen(paren_id, s, arc, w);
    583       else
    584         ProcCloseParen(paren_id, s, arc, w);
    585     } else {
    586       ProcNonParen(s, arc, w);
    587     }
    588   }
    589 }
    590 
    591 // Saves the shortest path info for reaching this parenthesis
    592 // and starts a new SSSP in the sub-graph pointed to by the parenthesis
    593 // if previously unvisited. Otherwise it finds any previously encountered
    594 // closing parentheses and relaxes them using the recursively stored
    595 // shortest distance to them.
    596 template<class Arc, class Queue> inline
    597 void PdtShortestPath<Arc, Queue>::ProcOpenParen(
    598     Label paren_id, SearchState s, Arc arc, Weight w) {
    599 
    600   SearchState d(arc.nextstate, arc.nextstate);
    601   ParenSpec paren(paren_id, s.start, d.start);
    602   Weight pdist = sp_data_.Distance(paren);
    603   if (pdist != Plus(pdist, w)) {
    604     sp_data_.SetDistance(paren, w);
    605     sp_data_.SetParent(paren, s);
    606     Weight dist = sp_data_.Distance(d);
    607     if (dist == Weight::Zero()) {
    608       Queue *state_queue = state_queue_;
    609       GetDistance(d.start);
    610       state_queue_ = state_queue;
    611     } else if (!(sp_data_.Flags(d) & kFinished)) {
    612       FSTERROR() << "PdtShortestPath: open parenthesis recursion: not bounded stack";
    613       error_ = true;
    614     }
    615 
    616     for (CloseSourceIterator set_iter =
    617              balance_data_.Find(paren_id, arc.nextstate);
    618          !set_iter.Done(); set_iter.Next()) {
    619       SearchState cpstate(set_iter.Element(), d.start);
    620       ParenState<Arc> paren_state(paren_id, cpstate.state);
    621       for (typename CloseParenMultimap::const_iterator cpit =
    622                close_paren_multimap_.find(paren_state);
    623            cpit != close_paren_multimap_.end() && paren_state == cpit->first;
    624            ++cpit) {
    625         const Arc &cparc = cpit->second;
    626         Weight cpw = Times(w, Times(sp_data_.Distance(cpstate),
    627                                     cparc.weight));
    628         Relax(cpstate, s, cparc, cpw, paren_id);
    629       }
    630     }
    631   }
    632 }
    633 
    634 // Saves the correspondence between each closing parenthesis and its
    635 // balancing open parenthesis info. Relaxes any close parenthesis
    636 // destination state that has a balancing previously encountered open
    637 // parenthesis.
    638 template<class Arc, class Queue> inline
    639 void PdtShortestPath<Arc, Queue>::ProcCloseParen(
    640     Label paren_id, SearchState s, const Arc &arc, Weight w) {
    641   ParenState<Arc> paren_state(paren_id, s.start);
    642   if (!(sp_data_.Flags(s) & kExpanded)) {
    643     balance_data_.CloseInsert(paren_id, s.start, s.state);
    644     sp_data_.SetFlags(s, kFinal, kFinal);
    645   }
    646 }
    647 
    648 // For non-parentheses, classical relaxation.
    649 template<class Arc, class Queue> inline
    650 void PdtShortestPath<Arc, Queue>::ProcNonParen(
    651     SearchState s, const Arc &arc, Weight w) {
    652   Relax(s, s, arc, w, kNoLabel);
    653 }
    654 
    655 // Classical relaxation on the search graph for 'arc' from state 's'.
    656 // State 't' is in the same sub-graph as the nextstate should be (i.e.
    657 // has the same paren 'start'.
    658 template<class Arc, class Queue> inline
    659 void PdtShortestPath<Arc, Queue>::Relax(
    660     SearchState s, SearchState t, Arc arc, Weight w, Label paren_id) {
    661   SearchState d(arc.nextstate, t.start);
    662   Weight dist = sp_data_.Distance(d);
    663   if (dist != Plus(dist, w)) {
    664     sp_data_.SetParent(d, s);
    665     sp_data_.SetParenId(d, paren_id);
    666     sp_data_.SetDistance(d, Plus(dist, w));
    667     Enqueue(d);
    668   }
    669 }
    670 
    671 template<class Arc, class Queue> inline
    672 void PdtShortestPath<Arc, Queue>::Enqueue(SearchState s) {
    673   if (!(sp_data_.Flags(s) & kEnqueued)) {
    674     state_queue_->Enqueue(s.state);
    675     sp_data_.SetFlags(s, kEnqueued, kEnqueued);
    676     ++nenqueued_;
    677   } else {
    678     state_queue_->Update(s.state);
    679   }
    680 }
    681 
    682 // Follows parent pointers to find the shortest path. Uses a stack
    683 // since the shortest distance is stored recursively.
    684 template<class Arc, class Queue>
    685 void PdtShortestPath<Arc, Queue>::GetPath() {
    686   SearchState s = f_parent_, d = SearchState(kNoStateId, kNoStateId);
    687   StateId s_p = kNoStateId, d_p = kNoStateId;
    688   Arc arc(kNoArc);
    689   Label paren_id = kNoLabel;
    690   stack<ParenSpec> paren_stack;
    691   while (s.state != kNoStateId) {
    692     d_p = s_p;
    693     s_p = ofst_->AddState();
    694     if (d.state == kNoStateId) {
    695       ofst_->SetFinal(s_p, ifst_->Final(f_parent_.state));
    696     } else {
    697       if (paren_id != kNoLabel) {                     // paren?
    698         if (arc.ilabel == parens_[paren_id].first) {  // open paren
    699           paren_stack.pop();
    700         } else {                                      // close paren
    701           ParenSpec paren(paren_id, d.start, s.start);
    702           paren_stack.push(paren);
    703         }
    704         if (!keep_parens_)
    705           arc.ilabel = arc.olabel = 0;
    706       }
    707       arc.nextstate = d_p;
    708       ofst_->AddArc(s_p, arc);
    709     }
    710     d = s;
    711     s = sp_data_.Parent(d);
    712     paren_id = sp_data_.ParenId(d);
    713     if (s.state != kNoStateId) {
    714       arc = GetPathArc(s, d, paren_id, false);
    715     } else if (!paren_stack.empty()) {
    716       ParenSpec paren = paren_stack.top();
    717       s = sp_data_.Parent(paren);
    718       paren_id = paren.paren_id;
    719       arc = GetPathArc(s, d, paren_id, true);
    720     }
    721   }
    722   ofst_->SetStart(s_p);
    723   ofst_->SetProperties(
    724       ShortestPathProperties(ofst_->Properties(kFstProperties, false)),
    725       kFstProperties);
    726 }
    727 
    728 
    729 // Finds transition with least weight between two states with label matching
    730 // paren_id and open/close paren type or a non-paren if kNoLabel.
    731 template<class Arc, class Queue>
    732 Arc PdtShortestPath<Arc, Queue>::GetPathArc(
    733     SearchState s, SearchState d, Label paren_id, bool open_paren) {
    734   Arc path_arc = kNoArc;
    735   for (ArcIterator< Fst<Arc> > aiter(*ifst_, s.state);
    736        !aiter.Done();
    737        aiter.Next()) {
    738     const Arc &arc = aiter.Value();
    739     if (arc.nextstate != d.state)
    740       continue;
    741     Label arc_paren_id = kNoLabel;
    742     typename unordered_map<Label, Label>::const_iterator pit
    743         = paren_id_map_.find(arc.ilabel);
    744     if (pit != paren_id_map_.end()) {
    745       arc_paren_id = pit->second;
    746       bool arc_open_paren = arc.ilabel == parens_[arc_paren_id].first;
    747       if (arc_open_paren != open_paren)
    748         continue;
    749     }
    750     if (arc_paren_id != paren_id)
    751       continue;
    752     if (arc.weight == Plus(arc.weight, path_arc.weight))
    753       path_arc = arc;
    754   }
    755   if (path_arc.nextstate == kNoStateId) {
    756     FSTERROR() << "PdtShortestPath::GetPathArc failed to find arc";
    757     error_ = true;
    758   }
    759   return path_arc;
    760 }
    761 
    762 template<class Arc, class Queue>
    763 const Arc PdtShortestPath<Arc, Queue>::kNoArc
    764     = Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId);
    765 
    766 template<class Arc, class Queue>
    767 const uint8 PdtShortestPath<Arc, Queue>::kEnqueued = 0x10;
    768 
    769 template<class Arc, class Queue>
    770 const uint8 PdtShortestPath<Arc, Queue>::kExpanded = 0x20;
    771 
    772 template<class Arc, class Queue>
    773 const uint8 PdtShortestPath<Arc, Queue>::kFinished = 0x40;
    774 
    775 template<class Arc, class Queue>
    776 void ShortestPath(const Fst<Arc> &ifst,
    777                   const vector<pair<typename Arc::Label,
    778                                     typename Arc::Label> > &parens,
    779                   MutableFst<Arc> *ofst,
    780                   const PdtShortestPathOptions<Arc, Queue> &opts) {
    781   PdtShortestPath<Arc, Queue> psp(ifst, parens, opts);
    782   psp.ShortestPath(ofst);
    783 }
    784 
    785 template<class Arc>
    786 void ShortestPath(const Fst<Arc> &ifst,
    787                   const vector<pair<typename Arc::Label,
    788                                     typename Arc::Label> > &parens,
    789                   MutableFst<Arc> *ofst) {
    790   typedef FifoQueue<typename Arc::StateId> Queue;
    791   PdtShortestPathOptions<Arc, Queue> opts;
    792   PdtShortestPath<Arc, Queue> psp(ifst, parens, opts);
    793   psp.ShortestPath(ofst);
    794 }
    795 
    796 }  // namespace fst
    797 
    798 #endif  // FST_EXTENSIONS_PDT_SHORTEST_PATH_H__
    799