Home | History | Annotate | Download | only in pdt
      1 // shortest-path.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: riley (at) google.com (Michael Riley)
     17 //
     18 // \file
     19 // Functions to find shortest paths in a PDT.
     20 
     21 #ifndef FST_EXTENSIONS_PDT_SHORTEST_PATH_H__
     22 #define FST_EXTENSIONS_PDT_SHORTEST_PATH_H__
     23 
     24 #include <fst/shortest-path.h>
     25 #include <fst/extensions/pdt/paren.h>
     26 #include <fst/extensions/pdt/pdt.h>
     27 
     28 #include <unordered_map>
     29 using std::tr1::unordered_map;
     30 using std::tr1::unordered_multimap;
     31 #include <tr1/unordered_set>
     32 using std::tr1::unordered_set;
     33 using std::tr1::unordered_multiset;
     34 #include <stack>
     35 #include <vector>
     36 using std::vector;
     37 
     38 namespace fst {
     39 
     40 template <class Arc, class Queue>
     41 struct PdtShortestPathOptions {
     42   bool keep_parentheses;
     43   bool path_gc;
     44 
     45   PdtShortestPathOptions(bool kp = false, bool gc = true)
     46       : keep_parentheses(kp), path_gc(gc) {}
     47 };
     48 
     49 
     50 // Class to store PDT shortest path results. Stores shortest path
     51 // tree info 'Distance()', Parent(), and ArcParent() information keyed
     52 // on two types:
     53 // (1) By SearchState: This is a usual node in a shortest path tree but:
     54 //    (a) is w.r.t a PDT search state - a pair of a PDT state and
     55 //        a 'start' state, which is either the PDT start state or
     56 //        the destination state of an open parenthesis.
     57 //    (b) the Distance() is from this 'start' state to the search state.
     58 //    (c) Parent().state is kNoLabel for the 'start' state.
     59 //
     60 // (2) By ParenSpec: This connects shortest path trees depending on the
     61 // the parenthesis taken. Given the parenthesis spec:
     62 //    (a) the Distance() is from the Parent() 'start' state to the
     63 //     parenthesis destination state.
     64 //    (b) the ArcParent() is the parenthesis arc.
     65 template <class Arc>
     66 class PdtShortestPathData {
     67  public:
     68   static const uint8 kFinal;
     69 
     70   typedef typename Arc::StateId StateId;
     71   typedef typename Arc::Weight Weight;
     72   typedef typename Arc::Label Label;
     73 
     74   struct SearchState {
     75     SearchState() : state(kNoStateId), start(kNoStateId) {}
     76 
     77     SearchState(StateId s, StateId t) : state(s), start(t) {}
     78 
     79     bool operator==(const SearchState &s) const {
     80       if (&s == this)
     81         return true;
     82       return s.state == this->state && s.start == this->start;
     83     }
     84 
     85     StateId state;  // PDT state
     86     StateId start;  // PDT paren 'source' state
     87   };
     88 
     89 
     90   // Specifies paren id, source and dest 'start' states of a paren.
     91   // These are the 'start' states of the respective sub-graphs.
     92   struct ParenSpec {
     93     ParenSpec()
     94         : paren_id(kNoLabel), src_start(kNoStateId), dest_start(kNoStateId) {}
     95 
     96     ParenSpec(Label id, StateId s, StateId d)
     97         : paren_id(id), src_start(s), dest_start(d) {}
     98 
     99     Label paren_id;        // Id of parenthesis
    100     StateId src_start;     // sub-graph 'start' state for paren source.
    101     StateId dest_start;    // sub-graph 'start' state for paren dest.
    102 
    103     bool operator==(const ParenSpec &x) const {
    104       if (&x == this)
    105         return true;
    106       return x.paren_id == this->paren_id &&
    107           x.src_start == this->src_start &&
    108           x.dest_start == this->dest_start;
    109     }
    110   };
    111 
    112   struct SearchData {
    113     SearchData() : distance(Weight::Zero()),
    114                    parent(kNoStateId, kNoStateId),
    115                    paren_id(kNoLabel),
    116                    flags(0) {}
    117 
    118     Weight distance;     // Distance to this state from PDT 'start' state
    119     SearchState parent;  // Parent state in shortest path tree
    120     int16 paren_id;      // If parent arc has paren, paren ID, o.w. kNoLabel
    121     uint8 flags;         // First byte reserved for PdtShortestPathData use
    122   };
    123 
    124   PdtShortestPathData(bool gc)
    125       : state_(kNoStateId, kNoStateId),
    126         paren_(kNoLabel, kNoStateId, kNoStateId),
    127         gc_(gc),
    128         nstates_(0),
    129         ngc_(0),
    130         finished_(false) {}
    131 
    132   ~PdtShortestPathData() {
    133     VLOG(1) << "opm size: " << paren_map_.size();
    134     VLOG(1) << "# of search states: " << nstates_;
    135     if (gc_)
    136       VLOG(1) << "# of GC'd search states: " << ngc_;
    137   }
    138 
    139   void Clear() {
    140     search_map_.clear();
    141     search_multimap_.clear();
    142     paren_map_.clear();
    143     state_ = SearchState(kNoStateId, kNoStateId);
    144     nstates_ = 0;
    145     ngc_ = 0;
    146   }
    147 
    148   Weight Distance(SearchState s) const {
    149     SearchData *data = GetSearchData(s);
    150     return data->distance;
    151   }
    152 
    153   Weight Distance(const ParenSpec &paren) const {
    154     SearchData *data = GetSearchData(paren);
    155     return data->distance;
    156   }
    157 
    158   SearchState Parent(SearchState s) const {
    159     SearchData *data = GetSearchData(s);
    160     return data->parent;
    161   }
    162 
    163   SearchState Parent(const ParenSpec &paren) const {
    164     SearchData *data = GetSearchData(paren);
    165     return data->parent;
    166   }
    167 
    168   Label ParenId(SearchState s) const {
    169     SearchData *data = GetSearchData(s);
    170     return data->paren_id;
    171   }
    172 
    173   uint8 Flags(SearchState s) const {
    174     SearchData *data = GetSearchData(s);
    175     return data->flags;
    176   }
    177 
    178   void SetDistance(SearchState s, Weight w) {
    179     SearchData *data = GetSearchData(s);
    180     data->distance = w;
    181   }
    182 
    183   void SetDistance(const ParenSpec &paren, Weight w) {
    184     SearchData *data = GetSearchData(paren);
    185     data->distance = w;
    186   }
    187 
    188   void SetParent(SearchState s, SearchState p) {
    189     SearchData *data = GetSearchData(s);
    190     data->parent = p;
    191   }
    192 
    193   void SetParent(const ParenSpec &paren, SearchState p) {
    194     SearchData *data = GetSearchData(paren);
    195     data->parent = p;
    196   }
    197 
    198   void SetParenId(SearchState s, Label p) {
    199     if (p >= 32768)
    200       FSTERROR() << "PdtShortestPathData: Paren ID does not fits in an int16";
    201     SearchData *data = GetSearchData(s);
    202     data->paren_id = p;
    203   }
    204 
    205   void SetFlags(SearchState s, uint8 f, uint8 mask) {
    206     SearchData *data = GetSearchData(s);
    207     data->flags &= ~mask;
    208     data->flags |= f & mask;
    209   }
    210 
    211   void GC(StateId s);
    212 
    213   void Finish() { finished_ = true; }
    214 
    215  private:
    216   static const Arc kNoArc;
    217   static const size_t kPrime0;
    218   static const size_t kPrime1;
    219   static const uint8 kInited;
    220   static const uint8 kMarked;
    221 
    222   // Hash for search state
    223   struct SearchStateHash {
    224     size_t operator()(const SearchState &s) const {
    225       return s.state + s.start * kPrime0;
    226     }
    227   };
    228 
    229   // Hash for paren map
    230   struct ParenHash {
    231     size_t operator()(const ParenSpec &paren) const {
    232       return paren.paren_id + paren.src_start * kPrime0 +
    233           paren.dest_start * kPrime1;
    234     }
    235   };
    236 
    237   typedef unordered_map<SearchState, SearchData, SearchStateHash> SearchMap;
    238 
    239   typedef unordered_multimap<StateId, StateId> SearchMultimap;
    240 
    241   // Hash map from paren spec to open paren data
    242   typedef unordered_map<ParenSpec, SearchData, ParenHash> ParenMap;
    243 
    244   SearchData *GetSearchData(SearchState s) const {
    245     if (s == state_)
    246       return state_data_;
    247     if (finished_) {
    248       typename SearchMap::iterator it = search_map_.find(s);
    249       if (it == search_map_.end())
    250         return &null_search_data_;
    251       state_ = s;
    252       return state_data_ = &(it->second);
    253     } else {
    254       state_ = s;
    255       state_data_ = &search_map_[s];
    256       if (!(state_data_->flags & kInited)) {
    257         ++nstates_;
    258         if (gc_)
    259           search_multimap_.insert(make_pair(s.start, s.state));
    260         state_data_->flags = kInited;
    261       }
    262       return state_data_;
    263     }
    264   }
    265 
    266   SearchData *GetSearchData(ParenSpec paren) const {
    267     if (paren == paren_)
    268       return paren_data_;
    269     if (finished_) {
    270       typename ParenMap::iterator it = paren_map_.find(paren);
    271       if (it == paren_map_.end())
    272         return &null_search_data_;
    273       paren_ = paren;
    274       return state_data_ = &(it->second);
    275     } else {
    276       paren_ = paren;
    277       return paren_data_ = &paren_map_[paren];
    278     }
    279   }
    280 
    281   mutable SearchMap search_map_;            // Maps from search state to data
    282   mutable SearchMultimap search_multimap_;  // Maps from 'start' to subgraph
    283   mutable ParenMap paren_map_;              // Maps paren spec to search data
    284   mutable SearchState state_;               // Last state accessed
    285   mutable SearchData *state_data_;          // Last state data accessed
    286   mutable ParenSpec paren_;                 // Last paren spec accessed
    287   mutable SearchData *paren_data_;          // Last paren data accessed
    288   bool gc_;                                 // Allow GC?
    289   mutable size_t nstates_;                  // Total number of search states
    290   size_t ngc_;                              // Number of GC'd search states
    291   mutable SearchData null_search_data_;     // Null search data
    292   bool finished_;                           // Read-only access when true
    293 
    294   DISALLOW_COPY_AND_ASSIGN(PdtShortestPathData);
    295 };
    296 
    297 // Deletes inaccessible search data from a given 'start' (open paren dest)
    298 // state. Assumes 'final' (close paren source or PDT final) states have
    299 // been flagged 'kFinal'.
    300 template<class Arc>
    301 void  PdtShortestPathData<Arc>::GC(StateId start) {
    302   if (!gc_)
    303     return;
    304   vector<StateId> final;
    305   for (typename SearchMultimap::iterator mmit = search_multimap_.find(start);
    306        mmit != search_multimap_.end() && mmit->first == start;
    307        ++mmit) {
    308     SearchState s(mmit->second, start);
    309     const SearchData &data = search_map_[s];
    310     if (data.flags & kFinal)
    311       final.push_back(s.state);
    312   }
    313 
    314   // Mark phase
    315   for (size_t i = 0; i < final.size(); ++i) {
    316     SearchState s(final[i], start);
    317     while (s.state != kNoLabel) {
    318       SearchData *sdata = &search_map_[s];
    319       if (sdata->flags & kMarked)
    320         break;
    321       sdata->flags |= kMarked;
    322       SearchState p = sdata->parent;
    323       if (p.start != start && p.start != kNoLabel) {  // entering sub-subgraph
    324         ParenSpec paren(sdata->paren_id, s.start, p.start);
    325         SearchData *pdata = &paren_map_[paren];
    326         s = pdata->parent;
    327       } else {
    328         s = p;
    329       }
    330     }
    331   }
    332 
    333   // Sweep phase
    334   typename SearchMultimap::iterator mmit = search_multimap_.find(start);
    335   while (mmit != search_multimap_.end() && mmit->first == start) {
    336     SearchState s(mmit->second, start);
    337     typename SearchMap::iterator mit = search_map_.find(s);
    338     const SearchData &data = mit->second;
    339     if (!(data.flags & kMarked)) {
    340       search_map_.erase(mit);
    341       ++ngc_;
    342     }
    343     search_multimap_.erase(mmit++);
    344   }
    345 }
    346 
// Placeholder arc: no labels, Zero() weight, no next state.
template<class Arc> const Arc PdtShortestPathData<Arc>::kNoArc
    = Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId);

// Multiplier applied to the 'start' fields by SearchStateHash and ParenHash.
template<class Arc> const size_t PdtShortestPathData<Arc>::kPrime0 = 7853;

// Multiplier applied to the paren destination 'start' by ParenHash.
template<class Arc> const size_t PdtShortestPathData<Arc>::kPrime1 = 7867;

// Flag bit set by GetSearchData() the first time a search state is created.
template<class Arc> const uint8 PdtShortestPathData<Arc>::kInited = 0x01;

// Flag bit identifying the states GC() starts its mark phase from.
template<class Arc> const uint8 PdtShortestPathData<Arc>::kFinal =  0x02;

// Flag bit set by GC()'s mark phase; unmarked states are erased by the sweep.
template<class Arc> const uint8 PdtShortestPathData<Arc>::kMarked = 0x04;
    359 
    360 
    361 // This computes the single source shortest (balanced) path (SSSP)
    362 // through a weighted PDT that has a bounded stack (i.e. is expandable
    363 // as an FST). It is a generalization of the classic SSSP graph
    364 // algorithm that removes a state s from a queue (defined by a
    365 // user-provided queue type) and relaxes the destination states of
    366 // transitions leaving s. In this PDT version, states that have
    367 // entering open parentheses are treated as source states for a
    368 // sub-graph SSSP problem with the shortest path up to the open
    369 // parenthesis being first saved. When a close parenthesis is then
    370 // encountered any balancing open parenthesis is examined for this
    371 // saved information and multiplied back. In this way, each sub-graph
    372 // is entered only once rather than repeatedly.  If every state in the
    373 // input PDT has the property that there is a unique 'start' state for
    374 // it with entering open parentheses, then this algorithm is quite
    375 // straight-forward. In general, this will not be the case, so the
    376 // algorithm (implicitly) creates a new graph where each state is a
    377 // pair of an original state and a possible parenthesis 'start' state
    378 // for that state.
    379 template<class Arc, class Queue>
    380 class PdtShortestPath {
    381  public:
    382   typedef typename Arc::StateId StateId;
    383   typedef typename Arc::Weight Weight;
    384   typedef typename Arc::Label Label;
    385 
    386   typedef PdtShortestPathData<Arc> SpData;
    387   typedef typename SpData::SearchState SearchState;
    388   typedef typename SpData::ParenSpec ParenSpec;
    389 
    390   typedef typename PdtParenReachable<Arc>::SetIterator StateSetIterator;
    391   typedef typename PdtBalanceData<Arc>::SetIterator CloseSourceIterator;
    392 
    393   PdtShortestPath(const Fst<Arc> &ifst,
    394                   const vector<pair<Label, Label> > &parens,
    395                   const PdtShortestPathOptions<Arc, Queue> &opts)
    396       : kFinal(SpData::kFinal),
    397         ifst_(ifst.Copy()),
    398         parens_(parens),
    399         keep_parens_(opts.keep_parentheses),
    400         start_(ifst.Start()),
    401         sp_data_(opts.path_gc),
    402         error_(false) {
    403 
    404     if ((Weight::Properties() & (kPath | kRightSemiring))
    405         != (kPath | kRightSemiring)) {
    406       FSTERROR() << "SingleShortestPath: Weight needs to have the path"
    407                  << " property and be right distributive: " << Weight::Type();
    408       error_ = true;
    409     }
    410 
    411     for (Label i = 0; i < parens.size(); ++i) {
    412       const pair<Label, Label>  &p = parens[i];
    413       paren_id_map_[p.first] = i;
    414       paren_id_map_[p.second] = i;
    415     }
    416   };
    417 
    418   ~PdtShortestPath() {
    419     VLOG(1) << "# of input states: " << CountStates(*ifst_);
    420     VLOG(1) << "# of enqueued: " << nenqueued_;
    421     VLOG(1) << "cpmm size: " << close_paren_multimap_.size();
    422     delete ifst_;
    423   }
    424 
    425   void ShortestPath(MutableFst<Arc> *ofst) {
    426     Init(ofst);
    427     GetDistance(start_);
    428     GetPath();
    429     sp_data_.Finish();
    430     if (error_) ofst->SetProperties(kError, kError);
    431   }
    432 
    433   const PdtShortestPathData<Arc> &GetShortestPathData() const {
    434     return sp_data_;
    435   }
    436 
    437   PdtBalanceData<Arc> *GetBalanceData() { return &balance_data_; }
    438 
    439  private:
    440   static const Arc kNoArc;
    441   static const uint8 kEnqueued;
    442   static const uint8 kExpanded;
    443   const uint8 kFinal;
    444 
    445  public:
    446   // Hash multimap from close paren label to an paren arc.
    447   typedef unordered_multimap<ParenState<Arc>, Arc,
    448                         typename ParenState<Arc>::Hash> CloseParenMultimap;
    449 
    450   const CloseParenMultimap &GetCloseParenMultimap() const {
    451     return close_paren_multimap_;
    452   }
    453 
    454  private:
    455   void Init(MutableFst<Arc> *ofst);
    456   void GetDistance(StateId start);
    457   void ProcFinal(SearchState s);
    458   void ProcArcs(SearchState s);
    459   void ProcOpenParen(Label paren_id, SearchState s, Arc arc, Weight w);
    460   void ProcCloseParen(Label paren_id, SearchState s, const Arc &arc, Weight w);
    461   void ProcNonParen(SearchState s, const Arc &arc, Weight w);
    462   void Relax(SearchState s, SearchState t, Arc arc, Weight w, Label paren_id);
    463   void Enqueue(SearchState d);
    464   void GetPath();
    465   Arc GetPathArc(SearchState s, SearchState p, Label paren_id, bool open);
    466 
    467   Fst<Arc> *ifst_;
    468   MutableFst<Arc> *ofst_;
    469   const vector<pair<Label, Label> > &parens_;
    470   bool keep_parens_;
    471   Queue *state_queue_;                   // current state queue
    472   StateId start_;
    473   Weight f_distance_;
    474   SearchState f_parent_;
    475   SpData sp_data_;
    476   unordered_map<Label, Label> paren_id_map_;
    477   CloseParenMultimap close_paren_multimap_;
    478   PdtBalanceData<Arc> balance_data_;
    479   ssize_t nenqueued_;
    480   bool error_;
    481 
    482   DISALLOW_COPY_AND_ASSIGN(PdtShortestPath);
    483 };
    484 
    485 template<class Arc, class Queue>
    486 void PdtShortestPath<Arc, Queue>::Init(MutableFst<Arc> *ofst) {
    487   ofst_ = ofst;
    488   ofst->DeleteStates();
    489   ofst->SetInputSymbols(ifst_->InputSymbols());
    490   ofst->SetOutputSymbols(ifst_->OutputSymbols());
    491 
    492   if (ifst_->Start() == kNoStateId)
    493     return;
    494 
    495   f_distance_ = Weight::Zero();
    496   f_parent_ = SearchState(kNoStateId, kNoStateId);
    497 
    498   sp_data_.Clear();
    499   close_paren_multimap_.clear();
    500   balance_data_.Clear();
    501   nenqueued_ = 0;
    502 
    503   // Find open parens per destination state and close parens per source state.
    504   for (StateIterator<Fst<Arc> > siter(*ifst_); !siter.Done(); siter.Next()) {
    505     StateId s = siter.Value();
    506     for (ArcIterator<Fst<Arc> > aiter(*ifst_, s);
    507          !aiter.Done(); aiter.Next()) {
    508       const Arc &arc = aiter.Value();
    509       typename unordered_map<Label, Label>::const_iterator pit
    510           = paren_id_map_.find(arc.ilabel);
    511       if (pit != paren_id_map_.end()) {               // Is a paren?
    512         Label paren_id = pit->second;
    513         if (arc.ilabel == parens_[paren_id].first) {  // Open paren
    514           balance_data_.OpenInsert(paren_id, arc.nextstate);
    515         } else {                                      // Close paren
    516           ParenState<Arc> paren_state(paren_id, s);
    517           close_paren_multimap_.insert(make_pair(paren_state, arc));
    518         }
    519       }
    520     }
    521   }
    522 }
    523 
    524 // Computes the shortest distance stored in a recursive way. Each
    525 // sub-graph (i.e. different paren 'start' state) begins with weight One().
    526 template<class Arc, class Queue>
    527 void PdtShortestPath<Arc, Queue>::GetDistance(StateId start) {
    528   if (start == kNoStateId)
    529     return;
    530 
    531   Queue state_queue;
    532   state_queue_ = &state_queue;
    533   SearchState q(start, start);
    534   Enqueue(q);
    535   sp_data_.SetDistance(q, Weight::One());
    536 
    537   while (!state_queue_->Empty()) {
    538     StateId state = state_queue_->Head();
    539     state_queue_->Dequeue();
    540     SearchState s(state, start);
    541     sp_data_.SetFlags(s, 0, kEnqueued);
    542     ProcFinal(s);
    543     ProcArcs(s);
    544     sp_data_.SetFlags(s, kExpanded, kExpanded);
    545   }
    546   balance_data_.FinishInsert(start);
    547   sp_data_.GC(start);
    548 }
    549 
    550 // Updates best complete path.
    551 template<class Arc, class Queue>
    552 void PdtShortestPath<Arc, Queue>::ProcFinal(SearchState s) {
    553   if (ifst_->Final(s.state) != Weight::Zero() && s.start == start_) {
    554     Weight w = Times(sp_data_.Distance(s),
    555                      ifst_->Final(s.state));
    556     if (f_distance_ != Plus(f_distance_, w)) {
    557       if (f_parent_.state != kNoStateId)
    558         sp_data_.SetFlags(f_parent_, 0, kFinal);
    559       sp_data_.SetFlags(s, kFinal, kFinal);
    560 
    561       f_distance_ = Plus(f_distance_, w);
    562       f_parent_ = s;
    563     }
    564   }
    565 }
    566 
    567 // Processes all arcs leaving the state s.
    568 template<class Arc, class Queue>
    569 void PdtShortestPath<Arc, Queue>::ProcArcs(SearchState s) {
    570   for (ArcIterator< Fst<Arc> > aiter(*ifst_, s.state);
    571        !aiter.Done();
    572        aiter.Next()) {
    573     Arc arc = aiter.Value();
    574     Weight w = Times(sp_data_.Distance(s), arc.weight);
    575 
    576     typename unordered_map<Label, Label>::const_iterator pit
    577         = paren_id_map_.find(arc.ilabel);
    578     if (pit != paren_id_map_.end()) {  // Is a paren?
    579       Label paren_id = pit->second;
    580       if (arc.ilabel == parens_[paren_id].first)
    581         ProcOpenParen(paren_id, s, arc, w);
    582       else
    583         ProcCloseParen(paren_id, s, arc, w);
    584     } else {
    585       ProcNonParen(s, arc, w);
    586     }
    587   }
    588 }
    589 
// Saves the shortest path info for reaching this parenthesis
// and starts a new SSSP in the sub-graph pointed to by the parenthesis
// if previously unvisited. It then finds any previously encountered
// closing parentheses and relaxes them using the recursively stored
// shortest distance to them.
//
//   paren_id: id of the open parenthesis on 'arc'.
//   s:        search state the arc leaves from.
//   arc:      the open parenthesis arc.
//   w:        distance of 's' times the arc weight.
// Nothing happens unless 'w' improves the distance stored for this paren.
template<class Arc, class Queue> inline
void PdtShortestPath<Arc, Queue>::ProcOpenParen(
    Label paren_id, SearchState s, Arc arc, Weight w) {

  // The paren destination is the 'start' state of its own sub-graph.
  SearchState d(arc.nextstate, arc.nextstate);
  ParenSpec paren(paren_id, s.start, d.start);
  Weight pdist = sp_data_.Distance(paren);
  if (pdist != Plus(pdist, w)) {
    sp_data_.SetDistance(paren, w);
    sp_data_.SetParent(paren, s);
    Weight dist = sp_data_.Distance(d);
    if (dist == Weight::Zero()) {
      // Sub-graph has no distance yet: solve it recursively. GetDistance()
      // installs its own queue, so the current one is saved and restored.
      Queue *state_queue = state_queue_;
      GetDistance(d.start);
      state_queue_ = state_queue;
    }
    // For each close paren source recorded for this paren and sub-graph,
    // relax the destinations of its close paren arcs back in s's sub-graph.
    for (CloseSourceIterator set_iter =
             balance_data_.Find(paren_id, arc.nextstate);
         !set_iter.Done(); set_iter.Next()) {
      SearchState cpstate(set_iter.Element(), d.start);
      ParenState<Arc> paren_state(paren_id, cpstate.state);
      for (typename CloseParenMultimap::const_iterator cpit =
               close_paren_multimap_.find(paren_state);
           cpit != close_paren_multimap_.end() && paren_state == cpit->first;
           ++cpit) {
        const Arc &cparc = cpit->second;
        // Weight to the open paren, then through the sub-graph to the close
        // paren source, then across the close paren arc.
        Weight cpw = Times(w, Times(sp_data_.Distance(cpstate),
                                    cparc.weight));
        Relax(cpstate, s, cparc, cpw, paren_id);
      }
    }
  }
}
    628 
    629 // Saves the correspondence between each closing parenthesis and its
    630 // balancing open parenthesis info. Relaxes any close parenthesis
    631 // destination state that has a balancing previously encountered open
    632 // parenthesis.
    633 template<class Arc, class Queue> inline
    634 void PdtShortestPath<Arc, Queue>::ProcCloseParen(
    635     Label paren_id, SearchState s, const Arc &arc, Weight w) {
    636   ParenState<Arc> paren_state(paren_id, s.start);
    637   if (!(sp_data_.Flags(s) & kExpanded)) {
    638     balance_data_.CloseInsert(paren_id, s.start, s.state);
    639     sp_data_.SetFlags(s, kFinal, kFinal);
    640   }
    641 }
    642 
    643 // For non-parentheses, classical relaxation.
    644 template<class Arc, class Queue> inline
    645 void PdtShortestPath<Arc, Queue>::ProcNonParen(
    646     SearchState s, const Arc &arc, Weight w) {
    647   Relax(s, s, arc, w, kNoLabel);
    648 }
    649 
    650 // Classical relaxation on the search graph for 'arc' from state 's'.
    651 // State 't' is in the same sub-graph as the nextstate should be (i.e.
    652 // has the same paren 'start'.
    653 template<class Arc, class Queue> inline
    654 void PdtShortestPath<Arc, Queue>::Relax(
    655     SearchState s, SearchState t, Arc arc, Weight w, Label paren_id) {
    656   SearchState d(arc.nextstate, t.start);
    657   Weight dist = sp_data_.Distance(d);
    658   if (dist != Plus(dist, w)) {
    659     sp_data_.SetParent(d, s);
    660     sp_data_.SetParenId(d, paren_id);
    661     sp_data_.SetDistance(d, Plus(dist, w));
    662     Enqueue(d);
    663   }
    664 }
    665 
    666 template<class Arc, class Queue> inline
    667 void PdtShortestPath<Arc, Queue>::Enqueue(SearchState s) {
    668   if (!(sp_data_.Flags(s) & kEnqueued)) {
    669     state_queue_->Enqueue(s.state);
    670     sp_data_.SetFlags(s, kEnqueued, kEnqueued);
    671     ++nenqueued_;
    672   } else {
    673     state_queue_->Update(s.state);
    674   }
    675 }
    676 
// Follows parent pointers to find the shortest path. Uses a stack
// since the shortest distance is stored recursively.
//
// The path is traced backward from the best final search state. Each
// iteration adds one output state for 's' and, except for the first, an arc
// from it to the output state added in the previous iteration. The last
// state added becomes the output start state.
template<class Arc, class Queue>
void PdtShortestPath<Arc, Queue>::GetPath() {
  // s: state being added; d: state added in the previous iteration.
  SearchState s = f_parent_, d = SearchState(kNoStateId, kNoStateId);
  // s_p / d_p: the output states created for s and d.
  StateId s_p = kNoStateId, d_p = kNoStateId;
  Arc arc(kNoArc);
  Label paren_id = kNoLabel;
  stack<ParenSpec> paren_stack;
  while (s.state != kNoStateId) {
    d_p = s_p;
    s_p = ofst_->AddState();
    if (d.state == kNoStateId) {
      // First iteration: this is the final state of the path.
      ofst_->SetFinal(s_p, ifst_->Final(f_parent_.state));
    } else {
      if (paren_id != kNoLabel) {                     // paren?
        // Walking backward, a close paren is met before its open paren:
        // push on close, pop on the matching open.
        if (arc.ilabel == parens_[paren_id].first) {  // open paren
          paren_stack.pop();
        } else {                                      // close paren
          ParenSpec paren(paren_id, d.start, s.start);
          paren_stack.push(paren);
        }
        if (!keep_parens_)
          arc.ilabel = arc.olabel = 0;
      }
      arc.nextstate = d_p;
      ofst_->AddArc(s_p, arc);
    }
    d = s;
    s = sp_data_.Parent(d);
    paren_id = sp_data_.ParenId(d);
    if (s.state != kNoStateId) {
      arc = GetPathArc(s, d, paren_id, false);
    } else if (!paren_stack.empty()) {
      // d has no parent within its sub-graph: continue from the state
      // recorded as the parent of the paren spec on top of the stack.
      ParenSpec paren = paren_stack.top();
      s = sp_data_.Parent(paren);
      paren_id = paren.paren_id;
      arc = GetPathArc(s, d, paren_id, true);
    }
  }
  ofst_->SetStart(s_p);
  ofst_->SetProperties(
      ShortestPathProperties(ofst_->Properties(kFstProperties, false)),
      kFstProperties);
}
    721 }
    722 
    723 
    724 // Finds transition with least weight between two states with label matching
    725 // paren_id and open/close paren type or a non-paren if kNoLabel.
    726 template<class Arc, class Queue>
    727 Arc PdtShortestPath<Arc, Queue>::GetPathArc(
    728     SearchState s, SearchState d, Label paren_id, bool open_paren) {
    729   Arc path_arc = kNoArc;
    730   for (ArcIterator< Fst<Arc> > aiter(*ifst_, s.state);
    731        !aiter.Done();
    732        aiter.Next()) {
    733     const Arc &arc = aiter.Value();
    734     if (arc.nextstate != d.state)
    735       continue;
    736     Label arc_paren_id = kNoLabel;
    737     typename unordered_map<Label, Label>::const_iterator pit
    738         = paren_id_map_.find(arc.ilabel);
    739     if (pit != paren_id_map_.end()) {
    740       arc_paren_id = pit->second;
    741       bool arc_open_paren = arc.ilabel == parens_[arc_paren_id].first;
    742       if (arc_open_paren != open_paren)
    743         continue;
    744     }
    745     if (arc_paren_id != paren_id)
    746       continue;
    747     if (arc.weight == Plus(arc.weight, path_arc.weight))
    748       path_arc = arc;
    749   }
    750   if (path_arc.nextstate == kNoStateId) {
    751     FSTERROR() << "PdtShortestPath::GetPathArc failed to find arc";
    752     error_ = true;
    753   }
    754   return path_arc;
    755 }
    756 
// Placeholder arc (no labels, Zero() weight, no next state); GetPathArc()
// returns it when no matching arc exists.
template<class Arc, class Queue>
const Arc PdtShortestPath<Arc, Queue>::kNoArc
    = Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId);

// Search-state flag: set by Enqueue(), cleared when the state is dequeued.
template<class Arc, class Queue>
const uint8 PdtShortestPath<Arc, Queue>::kEnqueued = 0x10;

// Search-state flag: set once a state's arcs have been processed.
template<class Arc, class Queue>
const uint8 PdtShortestPath<Arc, Queue>::kExpanded = 0x20;
    766 
    767 template<class Arc, class Queue>
    768 void ShortestPath(const Fst<Arc> &ifst,
    769                   const vector<pair<typename Arc::Label,
    770                                     typename Arc::Label> > &parens,
    771                   MutableFst<Arc> *ofst,
    772                   const PdtShortestPathOptions<Arc, Queue> &opts) {
    773   PdtShortestPath<Arc, Queue> psp(ifst, parens, opts);
    774   psp.ShortestPath(ofst);
    775 }
    776 
    777 template<class Arc>
    778 void ShortestPath(const Fst<Arc> &ifst,
    779                   const vector<pair<typename Arc::Label,
    780                                     typename Arc::Label> > &parens,
    781                   MutableFst<Arc> *ofst) {
    782   typedef FifoQueue<typename Arc::StateId> Queue;
    783   PdtShortestPathOptions<Arc, Queue> opts;
    784   PdtShortestPath<Arc, Queue> psp(ifst, parens, opts);
    785   psp.ShortestPath(ofst);
    786 }
    787 
    788 }  // namespace fst
    789 
    790 #endif  // FST_EXTENSIONS_PDT_SHORTEST_PATH_H__
    791