// paren.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: riley (at) google.com (Michael Riley)
//
// Common classes for PDT parentheses

// \file

#ifndef FST_EXTENSIONS_PDT_PAREN_H_
#define FST_EXTENSIONS_PDT_PAREN_H_

#include <algorithm>
#include <unordered_map>
#include <tr1/unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
#include <tr1/unordered_set>
using std::tr1::unordered_set;
using std::tr1::unordered_multiset;
#include <set>

#include <fst/extensions/pdt/pdt.h>
#include <fst/extensions/pdt/collection.h>
#include <fst/fst.h>
#include <fst/dfs-visit.h>


namespace fst {

//
// ParenState: Pair of an open (close) parenthesis and
// its destination (source) state.
45 // 46 47 template <class A> 48 class ParenState { 49 public: 50 typedef typename A::Label Label; 51 typedef typename A::StateId StateId; 52 53 struct Hash { 54 size_t operator()(const ParenState<A> &p) const { 55 return p.paren_id + p.state_id * kPrime; 56 } 57 }; 58 59 Label paren_id; // ID of open (close) paren 60 StateId state_id; // destination (source) state of open (close) paren 61 62 ParenState() : paren_id(kNoLabel), state_id(kNoStateId) {} 63 64 ParenState(Label p, StateId s) : paren_id(p), state_id(s) {} 65 66 bool operator==(const ParenState<A> &p) const { 67 if (&p == this) 68 return true; 69 return p.paren_id == this->paren_id && p.state_id == this->state_id; 70 } 71 72 bool operator!=(const ParenState<A> &p) const { return !(p == *this); } 73 74 bool operator<(const ParenState<A> &p) const { 75 return paren_id < this->paren.id || 76 (p.paren_id == this->paren.id && p.state_id < this->state_id); 77 } 78 79 private: 80 static const size_t kPrime; 81 }; 82 83 template <class A> 84 const size_t ParenState<A>::kPrime = 7853; 85 86 87 // Creates an FST-style iterator from STL map and iterator. 88 template <class M> 89 class MapIterator { 90 public: 91 typedef typename M::const_iterator StlIterator; 92 typedef typename M::value_type PairType; 93 typedef typename PairType::second_type ValueType; 94 95 MapIterator(const M &m, StlIterator iter) 96 : map_(m), begin_(iter), iter_(iter) {} 97 98 bool Done() const { 99 return iter_ == map_.end() || iter_->first != begin_->first; 100 } 101 102 ValueType Value() const { return iter_->second; } 103 void Next() { ++iter_; } 104 void Reset() { iter_ = begin_; } 105 106 private: 107 const M &map_; 108 StlIterator begin_; 109 StlIterator iter_; 110 }; 111 112 // 113 // PdtParenReachable: Provides various parenthesis reachability information 114 // on a PDT. 
115 // 116 117 template <class A> 118 class PdtParenReachable { 119 public: 120 typedef typename A::StateId StateId; 121 typedef typename A::Label Label; 122 public: 123 // Maps from state ID to reachable paren IDs from (to) that state. 124 typedef unordered_multimap<StateId, Label> ParenMultiMap; 125 126 // Maps from paren ID and state ID to reachable state set ID 127 typedef unordered_map<ParenState<A>, ssize_t, 128 typename ParenState<A>::Hash> StateSetMap; 129 130 // Maps from paren ID and state ID to arcs exiting that state with that 131 // Label. 132 typedef unordered_multimap<ParenState<A>, A, 133 typename ParenState<A>::Hash> ParenArcMultiMap; 134 135 typedef MapIterator<ParenMultiMap> ParenIterator; 136 137 typedef MapIterator<ParenArcMultiMap> ParenArcIterator; 138 139 typedef typename Collection<ssize_t, StateId>::SetIterator SetIterator; 140 141 // Computes close (open) parenthesis reachabilty information for 142 // a PDT with bounded stack. 143 PdtParenReachable(const Fst<A> &fst, 144 const vector<pair<Label, Label> > &parens, bool close) 145 : fst_(fst), 146 parens_(parens), 147 close_(close) { 148 for (Label i = 0; i < parens.size(); ++i) { 149 const pair<Label, Label> &p = parens[i]; 150 paren_id_map_[p.first] = i; 151 paren_id_map_[p.second] = i; 152 } 153 154 if (close_) { 155 StateId start = fst.Start(); 156 if (start == kNoStateId) 157 return; 158 DFSearch(start, start); 159 } else { 160 FSTERROR() << "PdtParenReachable: open paren info not implemented"; 161 } 162 } 163 164 // Given a state ID, returns an iterator over paren IDs 165 // for close (open) parens reachable from that state along balanced 166 // paths. 
167 ParenIterator FindParens(StateId s) const { 168 return ParenIterator(paren_multimap_, paren_multimap_.find(s)); 169 } 170 171 // Given a paren ID and a state ID s, returns an iterator over 172 // states that can be reached along balanced paths from (to) s that 173 // have have close (open) parentheses matching the paren ID exiting 174 // (entering) those states. 175 SetIterator FindStates(Label paren_id, StateId s) const { 176 ParenState<A> paren_state(paren_id, s); 177 typename StateSetMap::const_iterator id_it = set_map_.find(paren_state); 178 if (id_it == set_map_.end()) { 179 return state_sets_.FindSet(-1); 180 } else { 181 return state_sets_.FindSet(id_it->second); 182 } 183 } 184 185 // Given a paren Id and a state ID s, return an iterator over 186 // arcs that exit (enter) s and are labeled with a close (open) 187 // parenthesis matching the paren ID. 188 ParenArcIterator FindParenArcs(Label paren_id, StateId s) const { 189 ParenState<A> paren_state(paren_id, s); 190 return ParenArcIterator(paren_arc_multimap_, 191 paren_arc_multimap_.find(paren_state)); 192 } 193 194 private: 195 // DFS that gathers paren and state set information. 196 // Bool returns false when cycle detected. 197 bool DFSearch(StateId s, StateId start); 198 199 // Unions state sets together gathered by the DFS. 200 void ComputeStateSet(StateId s); 201 202 // Gather state set(s) from state 'nexts'. 203 void UpdateStateSet(StateId nexts, set<Label> *paren_set, 204 vector< set<StateId> > *state_sets) const; 205 206 const Fst<A> &fst_; 207 const vector<pair<Label, Label> > &parens_; // Paren ID -> Labels 208 bool close_; // Close/open paren info? 
209 unordered_map<Label, Label> paren_id_map_; // Paren labels -> ID 210 ParenMultiMap paren_multimap_; // Paren reachability 211 ParenArcMultiMap paren_arc_multimap_; // Paren Arcs 212 vector<char> state_color_; // DFS state 213 mutable Collection<ssize_t, StateId> state_sets_; // Reachable states -> ID 214 StateSetMap set_map_; // ID -> Reachable states 215 DISALLOW_COPY_AND_ASSIGN(PdtParenReachable); 216 }; 217 218 // DFS that gathers paren and state set information. 219 template <class A> 220 bool PdtParenReachable<A>::DFSearch(StateId s, StateId start) { 221 if (s >= state_color_.size()) 222 state_color_.resize(s + 1, kDfsWhite); 223 224 if (state_color_[s] == kDfsBlack) 225 return true; 226 227 if (state_color_[s] == kDfsGrey) 228 return false; 229 230 state_color_[s] = kDfsGrey; 231 232 for (ArcIterator<Fst<A> > aiter(fst_, s); 233 !aiter.Done(); 234 aiter.Next()) { 235 const A &arc = aiter.Value(); 236 237 typename unordered_map<Label, Label>::const_iterator pit 238 = paren_id_map_.find(arc.ilabel); 239 if (pit != paren_id_map_.end()) { // paren? 240 Label paren_id = pit->second; 241 if (arc.ilabel == parens_[paren_id].first) { // open paren 242 DFSearch(arc.nextstate, arc.nextstate); 243 for (SetIterator set_iter = FindStates(paren_id, arc.nextstate); 244 !set_iter.Done(); set_iter.Next()) { 245 for (ParenArcIterator paren_arc_iter = 246 FindParenArcs(paren_id, set_iter.Element()); 247 !paren_arc_iter.Done(); 248 paren_arc_iter.Next()) { 249 const A &cparc = paren_arc_iter.Value(); 250 DFSearch(cparc.nextstate, start); 251 } 252 } 253 } 254 } else { // non-paren 255 if(!DFSearch(arc.nextstate, start)) { 256 FSTERROR() << "PdtReachable: Underlying cyclicity not supported"; 257 return true; 258 } 259 } 260 } 261 ComputeStateSet(s); 262 state_color_[s] = kDfsBlack; 263 return true; 264 } 265 266 // Unions state sets together gathered by the DFS. 
267 template <class A> 268 void PdtParenReachable<A>::ComputeStateSet(StateId s) { 269 set<Label> paren_set; 270 vector< set<StateId> > state_sets(parens_.size()); 271 for (ArcIterator< Fst<A> > aiter(fst_, s); 272 !aiter.Done(); 273 aiter.Next()) { 274 const A &arc = aiter.Value(); 275 276 typename unordered_map<Label, Label>::const_iterator pit 277 = paren_id_map_.find(arc.ilabel); 278 if (pit != paren_id_map_.end()) { // paren? 279 Label paren_id = pit->second; 280 if (arc.ilabel == parens_[paren_id].first) { // open paren 281 for (SetIterator set_iter = 282 FindStates(paren_id, arc.nextstate); 283 !set_iter.Done(); set_iter.Next()) { 284 for (ParenArcIterator paren_arc_iter = 285 FindParenArcs(paren_id, set_iter.Element()); 286 !paren_arc_iter.Done(); 287 paren_arc_iter.Next()) { 288 const A &cparc = paren_arc_iter.Value(); 289 UpdateStateSet(cparc.nextstate, &paren_set, &state_sets); 290 } 291 } 292 } else { // close paren 293 paren_set.insert(paren_id); 294 state_sets[paren_id].insert(s); 295 ParenState<A> paren_state(paren_id, s); 296 paren_arc_multimap_.insert(make_pair(paren_state, arc)); 297 } 298 } else { // non-paren 299 UpdateStateSet(arc.nextstate, &paren_set, &state_sets); 300 } 301 } 302 303 vector<StateId> state_set; 304 for (typename set<Label>::iterator paren_iter = paren_set.begin(); 305 paren_iter != paren_set.end(); ++paren_iter) { 306 state_set.clear(); 307 Label paren_id = *paren_iter; 308 paren_multimap_.insert(make_pair(s, paren_id)); 309 for (typename set<StateId>::iterator state_iter 310 = state_sets[paren_id].begin(); 311 state_iter != state_sets[paren_id].end(); 312 ++state_iter) { 313 state_set.push_back(*state_iter); 314 } 315 ParenState<A> paren_state(paren_id, s); 316 set_map_[paren_state] = state_sets_.FindId(state_set); 317 } 318 } 319 320 // Gather state set(s) from state 'nexts'. 
321 template <class A> 322 void PdtParenReachable<A>::UpdateStateSet( 323 StateId nexts, set<Label> *paren_set, 324 vector< set<StateId> > *state_sets) const { 325 for(ParenIterator paren_iter = FindParens(nexts); 326 !paren_iter.Done(); paren_iter.Next()) { 327 Label paren_id = paren_iter.Value(); 328 paren_set->insert(paren_id); 329 for (SetIterator set_iter = FindStates(paren_id, nexts); 330 !set_iter.Done(); set_iter.Next()) { 331 (*state_sets)[paren_id].insert(set_iter.Element()); 332 } 333 } 334 } 335 336 337 // Store balancing parenthesis data for a PDT. Allows on-the-fly 338 // construction (e.g. in PdtShortestPath) unlike PdtParenReachable above. 339 template <class A> 340 class PdtBalanceData { 341 public: 342 typedef typename A::StateId StateId; 343 typedef typename A::Label Label; 344 345 // Hash set for open parens 346 typedef unordered_set<ParenState<A>, typename ParenState<A>::Hash> OpenParenSet; 347 348 // Maps from open paren destination state to parenthesis ID. 349 typedef unordered_multimap<StateId, Label> OpenParenMap; 350 351 // Maps from open paren state to source states of matching close parens 352 typedef unordered_multimap<ParenState<A>, StateId, 353 typename ParenState<A>::Hash> CloseParenMap; 354 355 // Maps from open paren state to close source set ID 356 typedef unordered_map<ParenState<A>, ssize_t, 357 typename ParenState<A>::Hash> CloseSourceMap; 358 359 typedef typename Collection<ssize_t, StateId>::SetIterator SetIterator; 360 361 PdtBalanceData() {} 362 363 void Clear() { 364 open_paren_map_.clear(); 365 close_paren_map_.clear(); 366 } 367 368 // Adds an open parenthesis with destination state 'open_dest'. 
369 void OpenInsert(Label paren_id, StateId open_dest) { 370 ParenState<A> key(paren_id, open_dest); 371 if (!open_paren_set_.count(key)) { 372 open_paren_set_.insert(key); 373 open_paren_map_.insert(make_pair(open_dest, paren_id)); 374 } 375 } 376 377 // Adds a matching closing parenthesis with source state 378 // 'close_source' that balances an open_parenthesis with destination 379 // state 'open_dest' if OpenInsert() previously called 380 // (o.w. CloseInsert() does nothing). 381 void CloseInsert(Label paren_id, StateId open_dest, StateId close_source) { 382 ParenState<A> key(paren_id, open_dest); 383 if (open_paren_set_.count(key)) 384 close_paren_map_.insert(make_pair(key, close_source)); 385 } 386 387 // Find close paren source states matching an open parenthesis. 388 // Methods that follow, iterate through those matching states. 389 // Should be called only after FinishInsert(open_dest). 390 SetIterator Find(Label paren_id, StateId open_dest) { 391 ParenState<A> close_key(paren_id, open_dest); 392 typename CloseSourceMap::const_iterator id_it = 393 close_source_map_.find(close_key); 394 if (id_it == close_source_map_.end()) { 395 return close_source_sets_.FindSet(-1); 396 } else { 397 return close_source_sets_.FindSet(id_it->second); 398 } 399 } 400 401 // Call when all open and close parenthesis insertions wrt open 402 // parentheses entering 'open_dest' are finished. Must be called 403 // before Find(open_dest). Stores close paren source state sets 404 // efficiently. 
405 void FinishInsert(StateId open_dest) { 406 vector<StateId> close_sources; 407 for (typename OpenParenMap::iterator oit = open_paren_map_.find(open_dest); 408 oit != open_paren_map_.end() && oit->first == open_dest;) { 409 Label paren_id = oit->second; 410 close_sources.clear(); 411 ParenState<A> okey(paren_id, open_dest); 412 open_paren_set_.erase(open_paren_set_.find(okey)); 413 for (typename CloseParenMap::iterator cit = close_paren_map_.find(okey); 414 cit != close_paren_map_.end() && cit->first == okey;) { 415 close_sources.push_back(cit->second); 416 close_paren_map_.erase(cit++); 417 } 418 sort(close_sources.begin(), close_sources.end()); 419 typename vector<StateId>::iterator unique_end = 420 unique(close_sources.begin(), close_sources.end()); 421 close_sources.resize(unique_end - close_sources.begin()); 422 423 if (!close_sources.empty()) 424 close_source_map_[okey] = close_source_sets_.FindId(close_sources); 425 open_paren_map_.erase(oit++); 426 } 427 } 428 429 // Return a new balance data object representing the reversed balance 430 // information. 431 PdtBalanceData<A> *Reverse(StateId num_states, 432 StateId num_split, 433 StateId state_id_shift) const; 434 435 private: 436 OpenParenSet open_paren_set_; // open par. at dest? 437 438 OpenParenMap open_paren_map_; // open parens per state 439 ParenState<A> open_dest_; // cur open dest. state 440 typename OpenParenMap::const_iterator open_iter_; // cur open parens/state 441 442 CloseParenMap close_paren_map_; // close states/open 443 // paren and state 444 445 CloseSourceMap close_source_map_; // paren, state to set ID 446 mutable Collection<ssize_t, StateId> close_source_sets_; 447 }; 448 449 // Return a new balance data object representing the reversed balance 450 // information. 
451 template <class A> 452 PdtBalanceData<A> *PdtBalanceData<A>::Reverse( 453 StateId num_states, 454 StateId num_split, 455 StateId state_id_shift) const { 456 PdtBalanceData<A> *bd = new PdtBalanceData<A>; 457 unordered_set<StateId> close_sources; 458 StateId split_size = num_states / num_split; 459 460 for (StateId i = 0; i < num_states; i+= split_size) { 461 close_sources.clear(); 462 463 for (typename CloseSourceMap::const_iterator 464 sit = close_source_map_.begin(); 465 sit != close_source_map_.end(); 466 ++sit) { 467 ParenState<A> okey = sit->first; 468 StateId open_dest = okey.state_id; 469 Label paren_id = okey.paren_id; 470 for (SetIterator set_iter = close_source_sets_.FindSet(sit->second); 471 !set_iter.Done(); set_iter.Next()) { 472 StateId close_source = set_iter.Element(); 473 if ((close_source < i) || (close_source >= i + split_size)) 474 continue; 475 close_sources.insert(close_source + state_id_shift); 476 bd->OpenInsert(paren_id, close_source + state_id_shift); 477 bd->CloseInsert(paren_id, close_source + state_id_shift, 478 open_dest + state_id_shift); 479 } 480 } 481 482 for (typename unordered_set<StateId>::const_iterator it 483 = close_sources.begin(); 484 it != close_sources.end(); 485 ++it) { 486 bd->FinishInsert(*it); 487 } 488 489 } 490 return bd; 491 } 492 493 494 } // namespace fst 495 496 #endif // FST_EXTENSIONS_PDT_PAREN_H_ 497