Home | History | Annotate | Download | only in pdt
      1 // paren.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: riley (at) google.com (Michael Riley)
     17 //
     18 // Common classes for PDT parentheses
     19 
     20 // \file
     21 
     22 #ifndef FST_EXTENSIONS_PDT_PAREN_H_
     23 #define FST_EXTENSIONS_PDT_PAREN_H_
     24 
     25 #include <algorithm>
     26 #include <unordered_map>
     27 using std::tr1::unordered_map;
     28 using std::tr1::unordered_multimap;
     29 #include <tr1/unordered_set>
     30 using std::tr1::unordered_set;
     31 using std::tr1::unordered_multiset;
     32 #include <set>
     33 
     34 #include <fst/extensions/pdt/pdt.h>
     35 #include <fst/extensions/pdt/collection.h>
     36 #include <fst/fst.h>
     37 #include <fst/dfs-visit.h>
     38 
     39 
     40 namespace fst {
     41 
     42 //
     43 // ParenState: Pair of an open (close) parenthesis and
     44 // its destination (source) state.
     45 //
     46 
     47 template <class A>
     48 class ParenState {
     49  public:
     50   typedef typename A::Label Label;
     51   typedef typename A::StateId StateId;
     52 
     53   struct Hash {
     54     size_t operator()(const ParenState<A> &p) const {
     55       return p.paren_id + p.state_id * kPrime;
     56     }
     57   };
     58 
     59   Label paren_id;     // ID of open (close) paren
     60   StateId state_id;   // destination (source) state of open (close) paren
     61 
     62   ParenState() : paren_id(kNoLabel), state_id(kNoStateId) {}
     63 
     64   ParenState(Label p, StateId s) : paren_id(p), state_id(s) {}
     65 
     66   bool operator==(const ParenState<A> &p) const {
     67     if (&p == this)
     68       return true;
     69     return p.paren_id == this->paren_id && p.state_id == this->state_id;
     70   }
     71 
     72   bool operator!=(const ParenState<A> &p) const { return !(p == *this); }
     73 
     74   bool operator<(const ParenState<A> &p) const {
     75     return paren_id < this->paren.id ||
     76         (p.paren_id == this->paren.id && p.state_id < this->state_id);
     77   }
     78 
     79  private:
     80   static const size_t kPrime;
     81 };
     82 
     83 template <class A>
     84 const size_t ParenState<A>::kPrime = 7853;
     85 
     86 
     87 // Creates an FST-style iterator from STL map and iterator.
     88 template <class M>
     89 class MapIterator {
     90  public:
     91   typedef typename M::const_iterator StlIterator;
     92   typedef typename M::value_type PairType;
     93   typedef typename PairType::second_type ValueType;
     94 
     95   MapIterator(const M &m, StlIterator iter)
     96       : map_(m), begin_(iter), iter_(iter) {}
     97 
     98   bool Done() const {
     99     return iter_ == map_.end() || iter_->first != begin_->first;
    100   }
    101 
    102   ValueType Value() const { return iter_->second; }
    103   void Next() { ++iter_; }
    104   void Reset() { iter_ = begin_; }
    105 
    106  private:
    107   const M &map_;
    108   StlIterator begin_;
    109   StlIterator iter_;
    110 };
    111 
    112 //
    113 // PdtParenReachable: Provides various parenthesis reachability information
    114 // on a PDT.
    115 //
    116 
    117 template <class A>
    118 class PdtParenReachable {
    119  public:
    120   typedef typename A::StateId StateId;
    121   typedef typename A::Label Label;
    122  public:
    123   // Maps from state ID to reachable paren IDs from (to) that state.
    124   typedef unordered_multimap<StateId, Label> ParenMultiMap;
    125 
    126   // Maps from paren ID and state ID to reachable state set ID
    127   typedef unordered_map<ParenState<A>, ssize_t,
    128                    typename ParenState<A>::Hash> StateSetMap;
    129 
    130   // Maps from paren ID and state ID to arcs exiting that state with that
    131   // Label.
    132   typedef unordered_multimap<ParenState<A>, A,
    133                         typename ParenState<A>::Hash> ParenArcMultiMap;
    134 
    135   typedef MapIterator<ParenMultiMap> ParenIterator;
    136 
    137   typedef MapIterator<ParenArcMultiMap> ParenArcIterator;
    138 
    139   typedef typename Collection<ssize_t, StateId>::SetIterator SetIterator;
    140 
    141   // Computes close (open) parenthesis reachabilty information for
    142   // a PDT with bounded stack.
    143   PdtParenReachable(const Fst<A> &fst,
    144                     const vector<pair<Label, Label> > &parens, bool close)
    145       : fst_(fst),
    146         parens_(parens),
    147         close_(close) {
    148     for (Label i = 0; i < parens.size(); ++i) {
    149       const pair<Label, Label>  &p = parens[i];
    150       paren_id_map_[p.first] = i;
    151       paren_id_map_[p.second] = i;
    152     }
    153 
    154     if (close_) {
    155       StateId start = fst.Start();
    156       if (start == kNoStateId)
    157         return;
    158       DFSearch(start, start);
    159     } else {
    160       FSTERROR() << "PdtParenReachable: open paren info not implemented";
    161     }
    162   }
    163 
    164   // Given a state ID, returns an iterator over paren IDs
    165   // for close (open) parens reachable from that state along balanced
    166   // paths.
    167   ParenIterator FindParens(StateId s) const {
    168     return ParenIterator(paren_multimap_, paren_multimap_.find(s));
    169   }
    170 
    171   // Given a paren ID and a state ID s, returns an iterator over
    172   // states that can be reached along balanced paths from (to) s that
    173   // have have close (open) parentheses matching the paren ID exiting
    174   // (entering) those states.
    175   SetIterator FindStates(Label paren_id, StateId s) const {
    176     ParenState<A> paren_state(paren_id, s);
    177     typename StateSetMap::const_iterator id_it = set_map_.find(paren_state);
    178     if (id_it == set_map_.end()) {
    179       return state_sets_.FindSet(-1);
    180     } else {
    181       return state_sets_.FindSet(id_it->second);
    182     }
    183   }
    184 
    185   // Given a paren Id and a state ID s, return an iterator over
    186   // arcs that exit (enter) s and are labeled with a close (open)
    187   // parenthesis matching the paren ID.
    188   ParenArcIterator FindParenArcs(Label paren_id, StateId s) const {
    189     ParenState<A> paren_state(paren_id, s);
    190     return ParenArcIterator(paren_arc_multimap_,
    191                             paren_arc_multimap_.find(paren_state));
    192   }
    193 
    194  private:
    195   // DFS that gathers paren and state set information.
    196   // Bool returns false when cycle detected.
    197   bool DFSearch(StateId s, StateId start);
    198 
    199   // Unions state sets together gathered by the DFS.
    200   void ComputeStateSet(StateId s);
    201 
    202   // Gather state set(s) from state 'nexts'.
    203   void UpdateStateSet(StateId nexts, set<Label> *paren_set,
    204                       vector< set<StateId> > *state_sets) const;
    205 
    206   const Fst<A> &fst_;
    207   const vector<pair<Label, Label> > &parens_;         // Paren ID -> Labels
    208   bool close_;                                        // Close/open paren info?
    209   unordered_map<Label, Label> paren_id_map_;               // Paren labels -> ID
    210   ParenMultiMap paren_multimap_;                      // Paren reachability
    211   ParenArcMultiMap paren_arc_multimap_;               // Paren Arcs
    212   vector<char> state_color_;                          // DFS state
    213   mutable Collection<ssize_t, StateId> state_sets_;   // Reachable states -> ID
    214   StateSetMap set_map_;                               // ID -> Reachable states
    215   DISALLOW_COPY_AND_ASSIGN(PdtParenReachable);
    216 };
    217 
    218 // DFS that gathers paren and state set information.
    219 template <class A>
    220 bool PdtParenReachable<A>::DFSearch(StateId s, StateId start) {
    221   if (s >= state_color_.size())
    222     state_color_.resize(s + 1, kDfsWhite);
    223 
    224   if (state_color_[s] == kDfsBlack)
    225     return true;
    226 
    227   if (state_color_[s] == kDfsGrey)
    228     return false;
    229 
    230   state_color_[s] = kDfsGrey;
    231 
    232   for (ArcIterator<Fst<A> > aiter(fst_, s);
    233        !aiter.Done();
    234        aiter.Next()) {
    235     const A &arc = aiter.Value();
    236 
    237     typename unordered_map<Label, Label>::const_iterator pit
    238         = paren_id_map_.find(arc.ilabel);
    239     if (pit != paren_id_map_.end()) {               // paren?
    240       Label paren_id = pit->second;
    241       if (arc.ilabel == parens_[paren_id].first) {  // open paren
    242         DFSearch(arc.nextstate, arc.nextstate);
    243         for (SetIterator set_iter = FindStates(paren_id, arc.nextstate);
    244              !set_iter.Done(); set_iter.Next()) {
    245           for (ParenArcIterator paren_arc_iter =
    246                    FindParenArcs(paren_id, set_iter.Element());
    247                !paren_arc_iter.Done();
    248                paren_arc_iter.Next()) {
    249             const A &cparc = paren_arc_iter.Value();
    250             DFSearch(cparc.nextstate, start);
    251           }
    252         }
    253       }
    254     } else {                                       // non-paren
    255       if(!DFSearch(arc.nextstate, start)) {
    256         FSTERROR() << "PdtReachable: Underlying cyclicity not supported";
    257         return true;
    258       }
    259     }
    260   }
    261   ComputeStateSet(s);
    262   state_color_[s] = kDfsBlack;
    263   return true;
    264 }
    265 
    266 // Unions state sets together gathered by the DFS.
    267 template <class A>
    268 void PdtParenReachable<A>::ComputeStateSet(StateId s) {
    269   set<Label> paren_set;
    270   vector< set<StateId> > state_sets(parens_.size());
    271   for (ArcIterator< Fst<A> > aiter(fst_, s);
    272        !aiter.Done();
    273        aiter.Next()) {
    274     const A &arc = aiter.Value();
    275 
    276     typename unordered_map<Label, Label>::const_iterator pit
    277         = paren_id_map_.find(arc.ilabel);
    278     if (pit != paren_id_map_.end()) {               // paren?
    279       Label paren_id = pit->second;
    280       if (arc.ilabel == parens_[paren_id].first) {  // open paren
    281         for (SetIterator set_iter =
    282                  FindStates(paren_id, arc.nextstate);
    283              !set_iter.Done(); set_iter.Next()) {
    284           for (ParenArcIterator paren_arc_iter =
    285                    FindParenArcs(paren_id, set_iter.Element());
    286                !paren_arc_iter.Done();
    287                paren_arc_iter.Next()) {
    288             const A &cparc = paren_arc_iter.Value();
    289             UpdateStateSet(cparc.nextstate, &paren_set, &state_sets);
    290           }
    291         }
    292       } else {                                      // close paren
    293         paren_set.insert(paren_id);
    294         state_sets[paren_id].insert(s);
    295         ParenState<A> paren_state(paren_id, s);
    296         paren_arc_multimap_.insert(make_pair(paren_state, arc));
    297       }
    298     } else {                                        // non-paren
    299       UpdateStateSet(arc.nextstate, &paren_set, &state_sets);
    300     }
    301   }
    302 
    303   vector<StateId> state_set;
    304   for (typename set<Label>::iterator paren_iter = paren_set.begin();
    305        paren_iter != paren_set.end(); ++paren_iter) {
    306     state_set.clear();
    307     Label paren_id = *paren_iter;
    308     paren_multimap_.insert(make_pair(s, paren_id));
    309     for (typename set<StateId>::iterator state_iter
    310              = state_sets[paren_id].begin();
    311          state_iter != state_sets[paren_id].end();
    312          ++state_iter) {
    313       state_set.push_back(*state_iter);
    314     }
    315     ParenState<A> paren_state(paren_id, s);
    316     set_map_[paren_state] = state_sets_.FindId(state_set);
    317   }
    318 }
    319 
    320 // Gather state set(s) from state 'nexts'.
    321 template <class A>
    322 void PdtParenReachable<A>::UpdateStateSet(
    323     StateId nexts, set<Label> *paren_set,
    324     vector< set<StateId> > *state_sets) const {
    325   for(ParenIterator paren_iter = FindParens(nexts);
    326       !paren_iter.Done(); paren_iter.Next()) {
    327     Label paren_id = paren_iter.Value();
    328     paren_set->insert(paren_id);
    329     for (SetIterator set_iter = FindStates(paren_id, nexts);
    330          !set_iter.Done(); set_iter.Next()) {
    331       (*state_sets)[paren_id].insert(set_iter.Element());
    332     }
    333   }
    334 }
    335 
    336 
    337 // Store balancing parenthesis data for a PDT. Allows on-the-fly
    338 // construction (e.g. in PdtShortestPath) unlike PdtParenReachable above.
    339 template <class A>
    340 class PdtBalanceData {
    341  public:
    342   typedef typename A::StateId StateId;
    343   typedef typename A::Label Label;
    344 
    345   // Hash set for open parens
    346   typedef unordered_set<ParenState<A>, typename ParenState<A>::Hash> OpenParenSet;
    347 
    348   // Maps from open paren destination state to parenthesis ID.
    349   typedef unordered_multimap<StateId, Label> OpenParenMap;
    350 
    351   // Maps from open paren state to source states of matching close parens
    352   typedef unordered_multimap<ParenState<A>, StateId,
    353                         typename ParenState<A>::Hash> CloseParenMap;
    354 
    355   // Maps from open paren state to close source set ID
    356   typedef unordered_map<ParenState<A>, ssize_t,
    357                    typename ParenState<A>::Hash> CloseSourceMap;
    358 
    359   typedef typename Collection<ssize_t, StateId>::SetIterator SetIterator;
    360 
    361   PdtBalanceData() {}
    362 
    363   void Clear() {
    364     open_paren_map_.clear();
    365     close_paren_map_.clear();
    366   }
    367 
    368   // Adds an open parenthesis with destination state 'open_dest'.
    369   void OpenInsert(Label paren_id, StateId open_dest) {
    370     ParenState<A> key(paren_id, open_dest);
    371     if (!open_paren_set_.count(key)) {
    372       open_paren_set_.insert(key);
    373       open_paren_map_.insert(make_pair(open_dest, paren_id));
    374     }
    375   }
    376 
    377   // Adds a matching closing parenthesis with source state
    378   // 'close_source' that balances an open_parenthesis with destination
    379   // state 'open_dest' if OpenInsert() previously called
    380   // (o.w. CloseInsert() does nothing).
    381   void CloseInsert(Label paren_id, StateId open_dest, StateId close_source) {
    382     ParenState<A> key(paren_id, open_dest);
    383     if (open_paren_set_.count(key))
    384       close_paren_map_.insert(make_pair(key, close_source));
    385   }
    386 
    387   // Find close paren source states matching an open parenthesis.
    388   // Methods that follow, iterate through those matching states.
    389   // Should be called only after FinishInsert(open_dest).
    390   SetIterator Find(Label paren_id, StateId open_dest) {
    391     ParenState<A> close_key(paren_id, open_dest);
    392     typename CloseSourceMap::const_iterator id_it =
    393         close_source_map_.find(close_key);
    394     if (id_it == close_source_map_.end()) {
    395       return close_source_sets_.FindSet(-1);
    396     } else {
    397       return close_source_sets_.FindSet(id_it->second);
    398     }
    399   }
    400 
    401   // Call when all open and close parenthesis insertions wrt open
    402   // parentheses entering 'open_dest' are finished. Must be called
    403   // before Find(open_dest). Stores close paren source state sets
    404   // efficiently.
    405   void FinishInsert(StateId open_dest) {
    406     vector<StateId> close_sources;
    407     for (typename OpenParenMap::iterator oit = open_paren_map_.find(open_dest);
    408          oit != open_paren_map_.end() && oit->first == open_dest;) {
    409       Label paren_id = oit->second;
    410       close_sources.clear();
    411       ParenState<A> okey(paren_id, open_dest);
    412       open_paren_set_.erase(open_paren_set_.find(okey));
    413       for (typename CloseParenMap::iterator cit = close_paren_map_.find(okey);
    414            cit != close_paren_map_.end() && cit->first == okey;) {
    415         close_sources.push_back(cit->second);
    416         close_paren_map_.erase(cit++);
    417       }
    418       sort(close_sources.begin(), close_sources.end());
    419       typename vector<StateId>::iterator unique_end =
    420           unique(close_sources.begin(), close_sources.end());
    421       close_sources.resize(unique_end - close_sources.begin());
    422 
    423       if (!close_sources.empty())
    424         close_source_map_[okey] = close_source_sets_.FindId(close_sources);
    425       open_paren_map_.erase(oit++);
    426     }
    427   }
    428 
    429   // Return a new balance data object representing the reversed balance
    430   // information.
    431   PdtBalanceData<A> *Reverse(StateId num_states,
    432                                StateId num_split,
    433                                StateId state_id_shift) const;
    434 
    435  private:
    436   OpenParenSet open_paren_set_;                      // open par. at dest?
    437 
    438   OpenParenMap open_paren_map_;                      // open parens per state
    439   ParenState<A> open_dest_;                          // cur open dest. state
    440   typename OpenParenMap::const_iterator open_iter_;  // cur open parens/state
    441 
    442   CloseParenMap close_paren_map_;                    // close states/open
    443                                                      //  paren and state
    444 
    445   CloseSourceMap close_source_map_;                  // paren, state to set ID
    446   mutable Collection<ssize_t, StateId> close_source_sets_;
    447 };
    448 
    449 // Return a new balance data object representing the reversed balance
    450 // information.
    451 template <class A>
    452 PdtBalanceData<A> *PdtBalanceData<A>::Reverse(
    453     StateId num_states,
    454     StateId num_split,
    455     StateId state_id_shift) const {
    456   PdtBalanceData<A> *bd = new PdtBalanceData<A>;
    457   unordered_set<StateId> close_sources;
    458   StateId split_size = num_states / num_split;
    459 
    460   for (StateId i = 0; i < num_states; i+= split_size) {
    461     close_sources.clear();
    462 
    463     for (typename CloseSourceMap::const_iterator
    464              sit = close_source_map_.begin();
    465          sit != close_source_map_.end();
    466          ++sit) {
    467       ParenState<A> okey = sit->first;
    468       StateId open_dest = okey.state_id;
    469       Label paren_id = okey.paren_id;
    470       for (SetIterator set_iter = close_source_sets_.FindSet(sit->second);
    471            !set_iter.Done(); set_iter.Next()) {
    472         StateId close_source = set_iter.Element();
    473         if ((close_source < i) || (close_source >= i + split_size))
    474           continue;
    475         close_sources.insert(close_source + state_id_shift);
    476         bd->OpenInsert(paren_id, close_source + state_id_shift);
    477         bd->CloseInsert(paren_id, close_source + state_id_shift,
    478                         open_dest + state_id_shift);
    479       }
    480     }
    481 
    482     for (typename unordered_set<StateId>::const_iterator it
    483              = close_sources.begin();
    484          it != close_sources.end();
    485          ++it) {
    486       bd->FinishInsert(*it);
    487     }
    488 
    489   }
    490   return bd;
    491 }
    492 
    493 
    494 }  // namespace fst
    495 
    496 #endif  // FST_EXTENSIONS_PDT_PAREN_H_
    497