// paren.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: riley (at) google.com (Michael Riley)
//
// Common classes for PDT parentheses

// \file

#ifndef FST_EXTENSIONS_PDT_PAREN_H_
#define FST_EXTENSIONS_PDT_PAREN_H_

#include <algorithm>
#include <set>
#include <utility>
#include <vector>
#include <tr1/unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
#include <tr1/unordered_set>
using std::tr1::unordered_set;
using std::tr1::unordered_multiset;

#include <fst/extensions/pdt/pdt.h>
#include <fst/extensions/pdt/collection.h>
#include <fst/fst.h>
#include <fst/dfs-visit.h>


namespace fst {

//
// ParenState: Pair of an open (close) parenthesis and
// its destination (source) state.
45 // 46 47 template <class A> 48 class ParenState { 49 public: 50 typedef typename A::Label Label; 51 typedef typename A::StateId StateId; 52 53 struct Hash { 54 size_t operator()(const ParenState<A> &p) const { 55 return p.paren_id + p.state_id * kPrime; 56 } 57 }; 58 59 Label paren_id; // ID of open (close) paren 60 StateId state_id; // destination (source) state of open (close) paren 61 62 ParenState() : paren_id(kNoLabel), state_id(kNoStateId) {} 63 64 ParenState(Label p, StateId s) : paren_id(p), state_id(s) {} 65 66 bool operator==(const ParenState<A> &p) const { 67 if (&p == this) 68 return true; 69 return p.paren_id == this->paren_id && p.state_id == this->state_id; 70 } 71 72 bool operator!=(const ParenState<A> &p) const { return !(p == *this); } 73 74 bool operator<(const ParenState<A> &p) const { 75 return paren_id < this->paren.id || 76 (p.paren_id == this->paren.id && p.state_id < this->state_id); 77 } 78 79 private: 80 static const size_t kPrime; 81 }; 82 83 template <class A> 84 const size_t ParenState<A>::kPrime = 7853; 85 86 87 // Creates an FST-style iterator from STL map and iterator. 88 template <class M> 89 class MapIterator { 90 public: 91 typedef typename M::const_iterator StlIterator; 92 typedef typename M::value_type PairType; 93 typedef typename PairType::second_type ValueType; 94 95 MapIterator(const M &m, StlIterator iter) 96 : map_(m), begin_(iter), iter_(iter) {} 97 98 bool Done() const { 99 return iter_ == map_.end() || iter_->first != begin_->first; 100 } 101 102 ValueType Value() const { return iter_->second; } 103 void Next() { ++iter_; } 104 void Reset() { iter_ = begin_; } 105 106 private: 107 const M &map_; 108 StlIterator begin_; 109 StlIterator iter_; 110 }; 111 112 // 113 // PdtParenReachable: Provides various parenthesis reachability information 114 // on a PDT. 
115 // 116 117 template <class A> 118 class PdtParenReachable { 119 public: 120 typedef typename A::StateId StateId; 121 typedef typename A::Label Label; 122 public: 123 // Maps from state ID to reachable paren IDs from (to) that state. 124 typedef unordered_multimap<StateId, Label> ParenMultiMap; 125 126 // Maps from paren ID and state ID to reachable state set ID 127 typedef unordered_map<ParenState<A>, ssize_t, 128 typename ParenState<A>::Hash> StateSetMap; 129 130 // Maps from paren ID and state ID to arcs exiting that state with that 131 // Label. 132 typedef unordered_multimap<ParenState<A>, A, 133 typename ParenState<A>::Hash> ParenArcMultiMap; 134 135 typedef MapIterator<ParenMultiMap> ParenIterator; 136 137 typedef MapIterator<ParenArcMultiMap> ParenArcIterator; 138 139 typedef typename Collection<ssize_t, StateId>::SetIterator SetIterator; 140 141 // Computes close (open) parenthesis reachabilty information for 142 // a PDT with bounded stack. 143 PdtParenReachable(const Fst<A> &fst, 144 const vector<pair<Label, Label> > &parens, bool close) 145 : fst_(fst), 146 parens_(parens), 147 close_(close), 148 error_(false) { 149 for (Label i = 0; i < parens.size(); ++i) { 150 const pair<Label, Label> &p = parens[i]; 151 paren_id_map_[p.first] = i; 152 paren_id_map_[p.second] = i; 153 } 154 155 if (close_) { 156 StateId start = fst.Start(); 157 if (start == kNoStateId) 158 return; 159 if (!DFSearch(start)) { 160 FSTERROR() << "PdtReachable: Underlying cyclicity not supported"; 161 error_ = true; 162 } 163 } else { 164 FSTERROR() << "PdtParenReachable: open paren info not implemented"; 165 error_ = true; 166 } 167 } 168 169 bool const Error() { return error_; } 170 171 // Given a state ID, returns an iterator over paren IDs 172 // for close (open) parens reachable from that state along balanced 173 // paths. 
174 ParenIterator FindParens(StateId s) const { 175 return ParenIterator(paren_multimap_, paren_multimap_.find(s)); 176 } 177 178 // Given a paren ID and a state ID s, returns an iterator over 179 // states that can be reached along balanced paths from (to) s that 180 // have have close (open) parentheses matching the paren ID exiting 181 // (entering) those states. 182 SetIterator FindStates(Label paren_id, StateId s) const { 183 ParenState<A> paren_state(paren_id, s); 184 typename StateSetMap::const_iterator id_it = set_map_.find(paren_state); 185 if (id_it == set_map_.end()) { 186 return state_sets_.FindSet(-1); 187 } else { 188 return state_sets_.FindSet(id_it->second); 189 } 190 } 191 192 // Given a paren Id and a state ID s, return an iterator over 193 // arcs that exit (enter) s and are labeled with a close (open) 194 // parenthesis matching the paren ID. 195 ParenArcIterator FindParenArcs(Label paren_id, StateId s) const { 196 ParenState<A> paren_state(paren_id, s); 197 return ParenArcIterator(paren_arc_multimap_, 198 paren_arc_multimap_.find(paren_state)); 199 } 200 201 private: 202 // DFS that gathers paren and state set information. 203 // Bool returns false when cycle detected. 204 bool DFSearch(StateId s); 205 206 // Unions state sets together gathered by the DFS. 207 void ComputeStateSet(StateId s); 208 209 // Gather state set(s) from state 'nexts'. 210 void UpdateStateSet(StateId nexts, set<Label> *paren_set, 211 vector< set<StateId> > *state_sets) const; 212 213 const Fst<A> &fst_; 214 const vector<pair<Label, Label> > &parens_; // Paren ID -> Labels 215 bool close_; // Close/open paren info? 
216 unordered_map<Label, Label> paren_id_map_; // Paren labels -> ID 217 ParenMultiMap paren_multimap_; // Paren reachability 218 ParenArcMultiMap paren_arc_multimap_; // Paren Arcs 219 vector<char> state_color_; // DFS state 220 mutable Collection<ssize_t, StateId> state_sets_; // Reachable states -> ID 221 StateSetMap set_map_; // ID -> Reachable states 222 bool error_; 223 DISALLOW_COPY_AND_ASSIGN(PdtParenReachable); 224 }; 225 226 // DFS that gathers paren and state set information. 227 template <class A> 228 bool PdtParenReachable<A>::DFSearch(StateId s) { 229 if (s >= state_color_.size()) 230 state_color_.resize(s + 1, kDfsWhite); 231 232 if (state_color_[s] == kDfsBlack) 233 return true; 234 235 if (state_color_[s] == kDfsGrey) 236 return false; 237 238 state_color_[s] = kDfsGrey; 239 240 for (ArcIterator<Fst<A> > aiter(fst_, s); 241 !aiter.Done(); 242 aiter.Next()) { 243 const A &arc = aiter.Value(); 244 245 typename unordered_map<Label, Label>::const_iterator pit 246 = paren_id_map_.find(arc.ilabel); 247 if (pit != paren_id_map_.end()) { // paren? 248 Label paren_id = pit->second; 249 if (arc.ilabel == parens_[paren_id].first) { // open paren 250 if (!DFSearch(arc.nextstate)) 251 return false; 252 for (SetIterator set_iter = FindStates(paren_id, arc.nextstate); 253 !set_iter.Done(); set_iter.Next()) { 254 for (ParenArcIterator paren_arc_iter = 255 FindParenArcs(paren_id, set_iter.Element()); 256 !paren_arc_iter.Done(); 257 paren_arc_iter.Next()) { 258 const A &cparc = paren_arc_iter.Value(); 259 if (!DFSearch(cparc.nextstate)) 260 return false; 261 } 262 } 263 } 264 } else { // non-paren 265 if(!DFSearch(arc.nextstate)) 266 return false; 267 } 268 } 269 ComputeStateSet(s); 270 state_color_[s] = kDfsBlack; 271 return true; 272 } 273 274 // Unions state sets together gathered by the DFS. 
275 template <class A> 276 void PdtParenReachable<A>::ComputeStateSet(StateId s) { 277 set<Label> paren_set; 278 vector< set<StateId> > state_sets(parens_.size()); 279 for (ArcIterator< Fst<A> > aiter(fst_, s); 280 !aiter.Done(); 281 aiter.Next()) { 282 const A &arc = aiter.Value(); 283 284 typename unordered_map<Label, Label>::const_iterator pit 285 = paren_id_map_.find(arc.ilabel); 286 if (pit != paren_id_map_.end()) { // paren? 287 Label paren_id = pit->second; 288 if (arc.ilabel == parens_[paren_id].first) { // open paren 289 for (SetIterator set_iter = 290 FindStates(paren_id, arc.nextstate); 291 !set_iter.Done(); set_iter.Next()) { 292 for (ParenArcIterator paren_arc_iter = 293 FindParenArcs(paren_id, set_iter.Element()); 294 !paren_arc_iter.Done(); 295 paren_arc_iter.Next()) { 296 const A &cparc = paren_arc_iter.Value(); 297 UpdateStateSet(cparc.nextstate, &paren_set, &state_sets); 298 } 299 } 300 } else { // close paren 301 paren_set.insert(paren_id); 302 state_sets[paren_id].insert(s); 303 ParenState<A> paren_state(paren_id, s); 304 paren_arc_multimap_.insert(make_pair(paren_state, arc)); 305 } 306 } else { // non-paren 307 UpdateStateSet(arc.nextstate, &paren_set, &state_sets); 308 } 309 } 310 311 vector<StateId> state_set; 312 for (typename set<Label>::iterator paren_iter = paren_set.begin(); 313 paren_iter != paren_set.end(); ++paren_iter) { 314 state_set.clear(); 315 Label paren_id = *paren_iter; 316 paren_multimap_.insert(make_pair(s, paren_id)); 317 for (typename set<StateId>::iterator state_iter 318 = state_sets[paren_id].begin(); 319 state_iter != state_sets[paren_id].end(); 320 ++state_iter) { 321 state_set.push_back(*state_iter); 322 } 323 ParenState<A> paren_state(paren_id, s); 324 set_map_[paren_state] = state_sets_.FindId(state_set); 325 } 326 } 327 328 // Gather state set(s) from state 'nexts'. 
329 template <class A> 330 void PdtParenReachable<A>::UpdateStateSet( 331 StateId nexts, set<Label> *paren_set, 332 vector< set<StateId> > *state_sets) const { 333 for(ParenIterator paren_iter = FindParens(nexts); 334 !paren_iter.Done(); paren_iter.Next()) { 335 Label paren_id = paren_iter.Value(); 336 paren_set->insert(paren_id); 337 for (SetIterator set_iter = FindStates(paren_id, nexts); 338 !set_iter.Done(); set_iter.Next()) { 339 (*state_sets)[paren_id].insert(set_iter.Element()); 340 } 341 } 342 } 343 344 345 // Store balancing parenthesis data for a PDT. Allows on-the-fly 346 // construction (e.g. in PdtShortestPath) unlike PdtParenReachable above. 347 template <class A> 348 class PdtBalanceData { 349 public: 350 typedef typename A::StateId StateId; 351 typedef typename A::Label Label; 352 353 // Hash set for open parens 354 typedef unordered_set<ParenState<A>, typename ParenState<A>::Hash> OpenParenSet; 355 356 // Maps from open paren destination state to parenthesis ID. 357 typedef unordered_multimap<StateId, Label> OpenParenMap; 358 359 // Maps from open paren state to source states of matching close parens 360 typedef unordered_multimap<ParenState<A>, StateId, 361 typename ParenState<A>::Hash> CloseParenMap; 362 363 // Maps from open paren state to close source set ID 364 typedef unordered_map<ParenState<A>, ssize_t, 365 typename ParenState<A>::Hash> CloseSourceMap; 366 367 typedef typename Collection<ssize_t, StateId>::SetIterator SetIterator; 368 369 PdtBalanceData() {} 370 371 void Clear() { 372 open_paren_map_.clear(); 373 close_paren_map_.clear(); 374 } 375 376 // Adds an open parenthesis with destination state 'open_dest'. 
377 void OpenInsert(Label paren_id, StateId open_dest) { 378 ParenState<A> key(paren_id, open_dest); 379 if (!open_paren_set_.count(key)) { 380 open_paren_set_.insert(key); 381 open_paren_map_.insert(make_pair(open_dest, paren_id)); 382 } 383 } 384 385 // Adds a matching closing parenthesis with source state 386 // 'close_source' that balances an open_parenthesis with destination 387 // state 'open_dest' if OpenInsert() previously called 388 // (o.w. CloseInsert() does nothing). 389 void CloseInsert(Label paren_id, StateId open_dest, StateId close_source) { 390 ParenState<A> key(paren_id, open_dest); 391 if (open_paren_set_.count(key)) 392 close_paren_map_.insert(make_pair(key, close_source)); 393 } 394 395 // Find close paren source states matching an open parenthesis. 396 // Methods that follow, iterate through those matching states. 397 // Should be called only after FinishInsert(open_dest). 398 SetIterator Find(Label paren_id, StateId open_dest) { 399 ParenState<A> close_key(paren_id, open_dest); 400 typename CloseSourceMap::const_iterator id_it = 401 close_source_map_.find(close_key); 402 if (id_it == close_source_map_.end()) { 403 return close_source_sets_.FindSet(-1); 404 } else { 405 return close_source_sets_.FindSet(id_it->second); 406 } 407 } 408 409 // Call when all open and close parenthesis insertions wrt open 410 // parentheses entering 'open_dest' are finished. Must be called 411 // before Find(open_dest). Stores close paren source state sets 412 // efficiently. 
413 void FinishInsert(StateId open_dest) { 414 vector<StateId> close_sources; 415 for (typename OpenParenMap::iterator oit = open_paren_map_.find(open_dest); 416 oit != open_paren_map_.end() && oit->first == open_dest;) { 417 Label paren_id = oit->second; 418 close_sources.clear(); 419 ParenState<A> okey(paren_id, open_dest); 420 open_paren_set_.erase(open_paren_set_.find(okey)); 421 for (typename CloseParenMap::iterator cit = close_paren_map_.find(okey); 422 cit != close_paren_map_.end() && cit->first == okey;) { 423 close_sources.push_back(cit->second); 424 close_paren_map_.erase(cit++); 425 } 426 sort(close_sources.begin(), close_sources.end()); 427 typename vector<StateId>::iterator unique_end = 428 unique(close_sources.begin(), close_sources.end()); 429 close_sources.resize(unique_end - close_sources.begin()); 430 431 if (!close_sources.empty()) 432 close_source_map_[okey] = close_source_sets_.FindId(close_sources); 433 open_paren_map_.erase(oit++); 434 } 435 } 436 437 // Return a new balance data object representing the reversed balance 438 // information. 439 PdtBalanceData<A> *Reverse(StateId num_states, 440 StateId num_split, 441 StateId state_id_shift) const; 442 443 private: 444 OpenParenSet open_paren_set_; // open par. at dest? 445 446 OpenParenMap open_paren_map_; // open parens per state 447 ParenState<A> open_dest_; // cur open dest. state 448 typename OpenParenMap::const_iterator open_iter_; // cur open parens/state 449 450 CloseParenMap close_paren_map_; // close states/open 451 // paren and state 452 453 CloseSourceMap close_source_map_; // paren, state to set ID 454 mutable Collection<ssize_t, StateId> close_source_sets_; 455 }; 456 457 // Return a new balance data object representing the reversed balance 458 // information. 
459 template <class A> 460 PdtBalanceData<A> *PdtBalanceData<A>::Reverse( 461 StateId num_states, 462 StateId num_split, 463 StateId state_id_shift) const { 464 PdtBalanceData<A> *bd = new PdtBalanceData<A>; 465 unordered_set<StateId> close_sources; 466 StateId split_size = num_states / num_split; 467 468 for (StateId i = 0; i < num_states; i+= split_size) { 469 close_sources.clear(); 470 471 for (typename CloseSourceMap::const_iterator 472 sit = close_source_map_.begin(); 473 sit != close_source_map_.end(); 474 ++sit) { 475 ParenState<A> okey = sit->first; 476 StateId open_dest = okey.state_id; 477 Label paren_id = okey.paren_id; 478 for (SetIterator set_iter = close_source_sets_.FindSet(sit->second); 479 !set_iter.Done(); set_iter.Next()) { 480 StateId close_source = set_iter.Element(); 481 if ((close_source < i) || (close_source >= i + split_size)) 482 continue; 483 close_sources.insert(close_source + state_id_shift); 484 bd->OpenInsert(paren_id, close_source + state_id_shift); 485 bd->CloseInsert(paren_id, close_source + state_id_shift, 486 open_dest + state_id_shift); 487 } 488 } 489 490 for (typename unordered_set<StateId>::const_iterator it 491 = close_sources.begin(); 492 it != close_sources.end(); 493 ++it) { 494 bd->FinishInsert(*it); 495 } 496 497 } 498 return bd; 499 } 500 501 502 } // namespace fst 503 504 #endif // FST_EXTENSIONS_PDT_PAREN_H_ 505