Home | History | Annotate | Download | only in pdt
      1 // paren.h
      2 
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 //
     15 // Copyright 2005-2010 Google, Inc.
     16 // Author: riley (at) google.com (Michael Riley)
     17 //
     18 // Common classes for PDT parentheses
     19 
     20 // \file
     21 
     22 #ifndef FST_EXTENSIONS_PDT_PAREN_H_
     23 #define FST_EXTENSIONS_PDT_PAREN_H_
     24 
     25 #include <algorithm>
     26 #include <tr1/unordered_map>
     27 using std::tr1::unordered_map;
     28 using std::tr1::unordered_multimap;
     29 #include <tr1/unordered_set>
     30 using std::tr1::unordered_set;
     31 using std::tr1::unordered_multiset;
     32 #include <set>
     33 
     34 #include <fst/extensions/pdt/pdt.h>
     35 #include <fst/extensions/pdt/collection.h>
     36 #include <fst/fst.h>
     37 #include <fst/dfs-visit.h>
     38 
     39 
     40 namespace fst {
     41 
     42 //
     43 // ParenState: Pair of an open (close) parenthesis and
     44 // its destination (source) state.
     45 //
     46 
     47 template <class A>
     48 class ParenState {
     49  public:
     50   typedef typename A::Label Label;
     51   typedef typename A::StateId StateId;
     52 
     53   struct Hash {
     54     size_t operator()(const ParenState<A> &p) const {
     55       return p.paren_id + p.state_id * kPrime;
     56     }
     57   };
     58 
     59   Label paren_id;     // ID of open (close) paren
     60   StateId state_id;   // destination (source) state of open (close) paren
     61 
     62   ParenState() : paren_id(kNoLabel), state_id(kNoStateId) {}
     63 
     64   ParenState(Label p, StateId s) : paren_id(p), state_id(s) {}
     65 
     66   bool operator==(const ParenState<A> &p) const {
     67     if (&p == this)
     68       return true;
     69     return p.paren_id == this->paren_id && p.state_id == this->state_id;
     70   }
     71 
     72   bool operator!=(const ParenState<A> &p) const { return !(p == *this); }
     73 
     74   bool operator<(const ParenState<A> &p) const {
     75     return paren_id < this->paren.id ||
     76         (p.paren_id == this->paren.id && p.state_id < this->state_id);
     77   }
     78 
     79  private:
     80   static const size_t kPrime;
     81 };
     82 
     83 template <class A>
     84 const size_t ParenState<A>::kPrime = 7853;
     85 
     86 
     87 // Creates an FST-style iterator from STL map and iterator.
     88 template <class M>
     89 class MapIterator {
     90  public:
     91   typedef typename M::const_iterator StlIterator;
     92   typedef typename M::value_type PairType;
     93   typedef typename PairType::second_type ValueType;
     94 
     95   MapIterator(const M &m, StlIterator iter)
     96       : map_(m), begin_(iter), iter_(iter) {}
     97 
     98   bool Done() const {
     99     return iter_ == map_.end() || iter_->first != begin_->first;
    100   }
    101 
    102   ValueType Value() const { return iter_->second; }
    103   void Next() { ++iter_; }
    104   void Reset() { iter_ = begin_; }
    105 
    106  private:
    107   const M &map_;
    108   StlIterator begin_;
    109   StlIterator iter_;
    110 };
    111 
    112 //
    113 // PdtParenReachable: Provides various parenthesis reachability information
    114 // on a PDT.
    115 //
    116 
    117 template <class A>
    118 class PdtParenReachable {
    119  public:
    120   typedef typename A::StateId StateId;
    121   typedef typename A::Label Label;
    122  public:
    123   // Maps from state ID to reachable paren IDs from (to) that state.
    124   typedef unordered_multimap<StateId, Label> ParenMultiMap;
    125 
    126   // Maps from paren ID and state ID to reachable state set ID
    127   typedef unordered_map<ParenState<A>, ssize_t,
    128                    typename ParenState<A>::Hash> StateSetMap;
    129 
    130   // Maps from paren ID and state ID to arcs exiting that state with that
    131   // Label.
    132   typedef unordered_multimap<ParenState<A>, A,
    133                         typename ParenState<A>::Hash> ParenArcMultiMap;
    134 
    135   typedef MapIterator<ParenMultiMap> ParenIterator;
    136 
    137   typedef MapIterator<ParenArcMultiMap> ParenArcIterator;
    138 
    139   typedef typename Collection<ssize_t, StateId>::SetIterator SetIterator;
    140 
    141   // Computes close (open) parenthesis reachabilty information for
    142   // a PDT with bounded stack.
    143   PdtParenReachable(const Fst<A> &fst,
    144                     const vector<pair<Label, Label> > &parens, bool close)
    145       : fst_(fst),
    146         parens_(parens),
    147         close_(close),
    148         error_(false) {
    149     for (Label i = 0; i < parens.size(); ++i) {
    150       const pair<Label, Label>  &p = parens[i];
    151       paren_id_map_[p.first] = i;
    152       paren_id_map_[p.second] = i;
    153     }
    154 
    155     if (close_) {
    156       StateId start = fst.Start();
    157       if (start == kNoStateId)
    158         return;
    159       if (!DFSearch(start)) {
    160         FSTERROR() << "PdtReachable: Underlying cyclicity not supported";
    161         error_ = true;
    162       }
    163     } else {
    164       FSTERROR() << "PdtParenReachable: open paren info not implemented";
    165       error_ = true;
    166     }
    167   }
    168 
    169   bool const Error() { return error_; }
    170 
    171   // Given a state ID, returns an iterator over paren IDs
    172   // for close (open) parens reachable from that state along balanced
    173   // paths.
    174   ParenIterator FindParens(StateId s) const {
    175     return ParenIterator(paren_multimap_, paren_multimap_.find(s));
    176   }
    177 
    178   // Given a paren ID and a state ID s, returns an iterator over
    179   // states that can be reached along balanced paths from (to) s that
    180   // have have close (open) parentheses matching the paren ID exiting
    181   // (entering) those states.
    182   SetIterator FindStates(Label paren_id, StateId s) const {
    183     ParenState<A> paren_state(paren_id, s);
    184     typename StateSetMap::const_iterator id_it = set_map_.find(paren_state);
    185     if (id_it == set_map_.end()) {
    186       return state_sets_.FindSet(-1);
    187     } else {
    188       return state_sets_.FindSet(id_it->second);
    189     }
    190   }
    191 
    192   // Given a paren Id and a state ID s, return an iterator over
    193   // arcs that exit (enter) s and are labeled with a close (open)
    194   // parenthesis matching the paren ID.
    195   ParenArcIterator FindParenArcs(Label paren_id, StateId s) const {
    196     ParenState<A> paren_state(paren_id, s);
    197     return ParenArcIterator(paren_arc_multimap_,
    198                             paren_arc_multimap_.find(paren_state));
    199   }
    200 
    201  private:
    202   // DFS that gathers paren and state set information.
    203   // Bool returns false when cycle detected.
    204   bool DFSearch(StateId s);
    205 
    206   // Unions state sets together gathered by the DFS.
    207   void ComputeStateSet(StateId s);
    208 
    209   // Gather state set(s) from state 'nexts'.
    210   void UpdateStateSet(StateId nexts, set<Label> *paren_set,
    211                       vector< set<StateId> > *state_sets) const;
    212 
    213   const Fst<A> &fst_;
    214   const vector<pair<Label, Label> > &parens_;         // Paren ID -> Labels
    215   bool close_;                                        // Close/open paren info?
    216   unordered_map<Label, Label> paren_id_map_;               // Paren labels -> ID
    217   ParenMultiMap paren_multimap_;                      // Paren reachability
    218   ParenArcMultiMap paren_arc_multimap_;               // Paren Arcs
    219   vector<char> state_color_;                          // DFS state
    220   mutable Collection<ssize_t, StateId> state_sets_;   // Reachable states -> ID
    221   StateSetMap set_map_;                               // ID -> Reachable states
    222   bool error_;
    223   DISALLOW_COPY_AND_ASSIGN(PdtParenReachable);
    224 };
    225 
    226 // DFS that gathers paren and state set information.
    227 template <class A>
    228 bool PdtParenReachable<A>::DFSearch(StateId s) {
    229   if (s >= state_color_.size())
    230     state_color_.resize(s + 1, kDfsWhite);
    231 
    232   if (state_color_[s] == kDfsBlack)
    233     return true;
    234 
    235   if (state_color_[s] == kDfsGrey)
    236     return false;
    237 
    238   state_color_[s] = kDfsGrey;
    239 
    240   for (ArcIterator<Fst<A> > aiter(fst_, s);
    241        !aiter.Done();
    242        aiter.Next()) {
    243     const A &arc = aiter.Value();
    244 
    245     typename unordered_map<Label, Label>::const_iterator pit
    246         = paren_id_map_.find(arc.ilabel);
    247     if (pit != paren_id_map_.end()) {               // paren?
    248       Label paren_id = pit->second;
    249       if (arc.ilabel == parens_[paren_id].first) {  // open paren
    250         if (!DFSearch(arc.nextstate))
    251           return false;
    252         for (SetIterator set_iter = FindStates(paren_id, arc.nextstate);
    253              !set_iter.Done(); set_iter.Next()) {
    254           for (ParenArcIterator paren_arc_iter =
    255                    FindParenArcs(paren_id, set_iter.Element());
    256                !paren_arc_iter.Done();
    257                paren_arc_iter.Next()) {
    258             const A &cparc = paren_arc_iter.Value();
    259             if (!DFSearch(cparc.nextstate))
    260               return false;
    261           }
    262         }
    263       }
    264     } else {                                       // non-paren
    265       if(!DFSearch(arc.nextstate))
    266         return false;
    267     }
    268   }
    269   ComputeStateSet(s);
    270   state_color_[s] = kDfsBlack;
    271   return true;
    272 }
    273 
    274 // Unions state sets together gathered by the DFS.
    275 template <class A>
    276 void PdtParenReachable<A>::ComputeStateSet(StateId s) {
    277   set<Label> paren_set;
    278   vector< set<StateId> > state_sets(parens_.size());
    279   for (ArcIterator< Fst<A> > aiter(fst_, s);
    280        !aiter.Done();
    281        aiter.Next()) {
    282     const A &arc = aiter.Value();
    283 
    284     typename unordered_map<Label, Label>::const_iterator pit
    285         = paren_id_map_.find(arc.ilabel);
    286     if (pit != paren_id_map_.end()) {               // paren?
    287       Label paren_id = pit->second;
    288       if (arc.ilabel == parens_[paren_id].first) {  // open paren
    289         for (SetIterator set_iter =
    290                  FindStates(paren_id, arc.nextstate);
    291              !set_iter.Done(); set_iter.Next()) {
    292           for (ParenArcIterator paren_arc_iter =
    293                    FindParenArcs(paren_id, set_iter.Element());
    294                !paren_arc_iter.Done();
    295                paren_arc_iter.Next()) {
    296             const A &cparc = paren_arc_iter.Value();
    297             UpdateStateSet(cparc.nextstate, &paren_set, &state_sets);
    298           }
    299         }
    300       } else {                                      // close paren
    301         paren_set.insert(paren_id);
    302         state_sets[paren_id].insert(s);
    303         ParenState<A> paren_state(paren_id, s);
    304         paren_arc_multimap_.insert(make_pair(paren_state, arc));
    305       }
    306     } else {                                        // non-paren
    307       UpdateStateSet(arc.nextstate, &paren_set, &state_sets);
    308     }
    309   }
    310 
    311   vector<StateId> state_set;
    312   for (typename set<Label>::iterator paren_iter = paren_set.begin();
    313        paren_iter != paren_set.end(); ++paren_iter) {
    314     state_set.clear();
    315     Label paren_id = *paren_iter;
    316     paren_multimap_.insert(make_pair(s, paren_id));
    317     for (typename set<StateId>::iterator state_iter
    318              = state_sets[paren_id].begin();
    319          state_iter != state_sets[paren_id].end();
    320          ++state_iter) {
    321       state_set.push_back(*state_iter);
    322     }
    323     ParenState<A> paren_state(paren_id, s);
    324     set_map_[paren_state] = state_sets_.FindId(state_set);
    325   }
    326 }
    327 
    328 // Gather state set(s) from state 'nexts'.
    329 template <class A>
    330 void PdtParenReachable<A>::UpdateStateSet(
    331     StateId nexts, set<Label> *paren_set,
    332     vector< set<StateId> > *state_sets) const {
    333   for(ParenIterator paren_iter = FindParens(nexts);
    334       !paren_iter.Done(); paren_iter.Next()) {
    335     Label paren_id = paren_iter.Value();
    336     paren_set->insert(paren_id);
    337     for (SetIterator set_iter = FindStates(paren_id, nexts);
    338          !set_iter.Done(); set_iter.Next()) {
    339       (*state_sets)[paren_id].insert(set_iter.Element());
    340     }
    341   }
    342 }
    343 
    344 
    345 // Store balancing parenthesis data for a PDT. Allows on-the-fly
    346 // construction (e.g. in PdtShortestPath) unlike PdtParenReachable above.
    347 template <class A>
    348 class PdtBalanceData {
    349  public:
    350   typedef typename A::StateId StateId;
    351   typedef typename A::Label Label;
    352 
    353   // Hash set for open parens
    354   typedef unordered_set<ParenState<A>, typename ParenState<A>::Hash> OpenParenSet;
    355 
    356   // Maps from open paren destination state to parenthesis ID.
    357   typedef unordered_multimap<StateId, Label> OpenParenMap;
    358 
    359   // Maps from open paren state to source states of matching close parens
    360   typedef unordered_multimap<ParenState<A>, StateId,
    361                         typename ParenState<A>::Hash> CloseParenMap;
    362 
    363   // Maps from open paren state to close source set ID
    364   typedef unordered_map<ParenState<A>, ssize_t,
    365                    typename ParenState<A>::Hash> CloseSourceMap;
    366 
    367   typedef typename Collection<ssize_t, StateId>::SetIterator SetIterator;
    368 
    369   PdtBalanceData() {}
    370 
    371   void Clear() {
    372     open_paren_map_.clear();
    373     close_paren_map_.clear();
    374   }
    375 
    376   // Adds an open parenthesis with destination state 'open_dest'.
    377   void OpenInsert(Label paren_id, StateId open_dest) {
    378     ParenState<A> key(paren_id, open_dest);
    379     if (!open_paren_set_.count(key)) {
    380       open_paren_set_.insert(key);
    381       open_paren_map_.insert(make_pair(open_dest, paren_id));
    382     }
    383   }
    384 
    385   // Adds a matching closing parenthesis with source state
    386   // 'close_source' that balances an open_parenthesis with destination
    387   // state 'open_dest' if OpenInsert() previously called
    388   // (o.w. CloseInsert() does nothing).
    389   void CloseInsert(Label paren_id, StateId open_dest, StateId close_source) {
    390     ParenState<A> key(paren_id, open_dest);
    391     if (open_paren_set_.count(key))
    392       close_paren_map_.insert(make_pair(key, close_source));
    393   }
    394 
    395   // Find close paren source states matching an open parenthesis.
    396   // Methods that follow, iterate through those matching states.
    397   // Should be called only after FinishInsert(open_dest).
    398   SetIterator Find(Label paren_id, StateId open_dest) {
    399     ParenState<A> close_key(paren_id, open_dest);
    400     typename CloseSourceMap::const_iterator id_it =
    401         close_source_map_.find(close_key);
    402     if (id_it == close_source_map_.end()) {
    403       return close_source_sets_.FindSet(-1);
    404     } else {
    405       return close_source_sets_.FindSet(id_it->second);
    406     }
    407   }
    408 
    409   // Call when all open and close parenthesis insertions wrt open
    410   // parentheses entering 'open_dest' are finished. Must be called
    411   // before Find(open_dest). Stores close paren source state sets
    412   // efficiently.
    413   void FinishInsert(StateId open_dest) {
    414     vector<StateId> close_sources;
    415     for (typename OpenParenMap::iterator oit = open_paren_map_.find(open_dest);
    416          oit != open_paren_map_.end() && oit->first == open_dest;) {
    417       Label paren_id = oit->second;
    418       close_sources.clear();
    419       ParenState<A> okey(paren_id, open_dest);
    420       open_paren_set_.erase(open_paren_set_.find(okey));
    421       for (typename CloseParenMap::iterator cit = close_paren_map_.find(okey);
    422            cit != close_paren_map_.end() && cit->first == okey;) {
    423         close_sources.push_back(cit->second);
    424         close_paren_map_.erase(cit++);
    425       }
    426       sort(close_sources.begin(), close_sources.end());
    427       typename vector<StateId>::iterator unique_end =
    428           unique(close_sources.begin(), close_sources.end());
    429       close_sources.resize(unique_end - close_sources.begin());
    430 
    431       if (!close_sources.empty())
    432         close_source_map_[okey] = close_source_sets_.FindId(close_sources);
    433       open_paren_map_.erase(oit++);
    434     }
    435   }
    436 
    437   // Return a new balance data object representing the reversed balance
    438   // information.
    439   PdtBalanceData<A> *Reverse(StateId num_states,
    440                                StateId num_split,
    441                                StateId state_id_shift) const;
    442 
    443  private:
    444   OpenParenSet open_paren_set_;                      // open par. at dest?
    445 
    446   OpenParenMap open_paren_map_;                      // open parens per state
    447   ParenState<A> open_dest_;                          // cur open dest. state
    448   typename OpenParenMap::const_iterator open_iter_;  // cur open parens/state
    449 
    450   CloseParenMap close_paren_map_;                    // close states/open
    451                                                      //  paren and state
    452 
    453   CloseSourceMap close_source_map_;                  // paren, state to set ID
    454   mutable Collection<ssize_t, StateId> close_source_sets_;
    455 };
    456 
    457 // Return a new balance data object representing the reversed balance
    458 // information.
    459 template <class A>
    460 PdtBalanceData<A> *PdtBalanceData<A>::Reverse(
    461     StateId num_states,
    462     StateId num_split,
    463     StateId state_id_shift) const {
    464   PdtBalanceData<A> *bd = new PdtBalanceData<A>;
    465   unordered_set<StateId> close_sources;
    466   StateId split_size = num_states / num_split;
    467 
    468   for (StateId i = 0; i < num_states; i+= split_size) {
    469     close_sources.clear();
    470 
    471     for (typename CloseSourceMap::const_iterator
    472              sit = close_source_map_.begin();
    473          sit != close_source_map_.end();
    474          ++sit) {
    475       ParenState<A> okey = sit->first;
    476       StateId open_dest = okey.state_id;
    477       Label paren_id = okey.paren_id;
    478       for (SetIterator set_iter = close_source_sets_.FindSet(sit->second);
    479            !set_iter.Done(); set_iter.Next()) {
    480         StateId close_source = set_iter.Element();
    481         if ((close_source < i) || (close_source >= i + split_size))
    482           continue;
    483         close_sources.insert(close_source + state_id_shift);
    484         bd->OpenInsert(paren_id, close_source + state_id_shift);
    485         bd->CloseInsert(paren_id, close_source + state_id_shift,
    486                         open_dest + state_id_shift);
    487       }
    488     }
    489 
    490     for (typename unordered_set<StateId>::const_iterator it
    491              = close_sources.begin();
    492          it != close_sources.end();
    493          ++it) {
    494       bd->FinishInsert(*it);
    495     }
    496 
    497   }
    498   return bd;
    499 }
    500 
    501 
    502 }  // namespace fst
    503 
    504 #endif  // FST_EXTENSIONS_PDT_PAREN_H_
    505